This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch theta
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit ed9ef30461aada12b44179f7ffce05cd02b92432
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Mon Jan 6 16:22:39 2025 -0800

    theta prototype
---
 .../spark/sql/aggregate/ThetaSketchBuild.scala     | 137 +++++++++++++++++++++
 .../apache/spark/sql/aggregate/ThetaUnion.scala    | 133 ++++++++++++++++++++
 .../spark/sql/expressions/ThetaExpressions.scala   |  74 +++++++++++
 .../scala/org/apache/spark/sql/functions_ds.scala  |  52 +++++++-
 .../registrar/DatasketchesFunctionRegistry.scala   |  12 +-
 .../apache/spark/sql/types/ThetaSketchType.scala   |  37 ++++++
 .../spark/sql/types/ThetaSketchWrapper.scala       |  51 ++++++++
 .../scala/org/apache/spark/sql/ThetaTest.scala     | 132 ++++++++++++++++++++
 8 files changed, 619 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala b/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala
new file mode 100644
index 0000000..d1bb88d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.aggregate
+
+import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType, ThetaSketchWrapper, ThetaSketchType}
+
+/**
+ * The ThetaSketchBuild function creates a Theta sketch from a column of values;
+ * the sketch can then be used to estimate the number of distinct values.
+ *
+ * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
+ *
+ * @param child child expression, from which to build a sketch
+ * @param lgk the size-accuracy trade-off parameter for the sketch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, lgk) - Creates a Theta Sketch and returns the binary 
representation.
+      `lgk` (optional, default: 12) the size-accuracy trade-off parameter.""",
+  examples = """
+    Example:
+      > SELECT theta_sketch_get_estimate(_FUNC_(col, 12)) FROM VALUES (1), 
(2), (3), (4), (5) tab(col);
+       5.0
+  """,
+)
+// scalastyle:on line.size.limit
+case class ThetaSketchBuild(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[ThetaSketchWrapper]
+    with BinaryLike[Expression]
+    with ExpectsInputTypes {
+
+  lazy val lgk: Int = {
+    right.eval() match {
+      case null => 12
+      case lgk: Int => lgk
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
+  def this(child: Expression) = this(child, Literal(12), 0, 0)
+  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
+  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchBuild =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchBuild =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaSketchBuild = {
+    copy(left = newLeft, right = newRight)
+  }
+
+  override def prettyName: String = "theta_sketch_build"
+
+  override def dataType: DataType = ThetaSketchType
+
+  override def nullable: Boolean = false
+
+  // one entry per child: the value column, then the integer lgk parameter
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+
+  override def createAggregationBuffer(): ThetaSketchWrapper =
+    new ThetaSketchWrapper(updateSketch = Some(UpdateSketch.builder().setLogNominalEntries(lgk).build()))
+
+  override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
+    val value = left.eval(input)
+    if (value != null) {
+      left.dataType match {
+        case DoubleType => wrapper.updateSketch.get.update(value.asInstanceOf[Double])
+        case FloatType => wrapper.updateSketch.get.update(value.asInstanceOf[Float])
+        case IntegerType => wrapper.updateSketch.get.update(value.asInstanceOf[Int])
+        case LongType => wrapper.updateSketch.get.update(value.asInstanceOf[Long])
+        case _ => throw new SparkUnsupportedOperationException(
+          s"Unsupported input type ${left.dataType.catalogString}",
+          Map("dataType" -> dataType.toString))
+      }
+    }
+    wrapper
+  }
+
+  override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
+    // a deserialized buffer carries its state as a compact sketch, if any
+    if (other != null && other.compactSketch.exists(!_.isEmpty)) {
+      if (wrapper.union.isEmpty) {
+        wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        if (wrapper.compactSketch.isDefined) {
+          wrapper.union.get.union(wrapper.compactSketch.get)
+          wrapper.compactSketch = None
+        }
+      }
+      wrapper.union.get.union(other.compactSketch.get)
+    }
+    wrapper
+  }
+
+  override def eval(wrapper: ThetaSketchWrapper): Any = {
+    if (wrapper == null) {
+      null
+    } else if (wrapper.union.isDefined) {
+      wrapper.union.get.getResult.toByteArrayCompressed
+    } else if (wrapper.updateSketch.isDefined) {
+      // no merge happened (e.g. complete-mode aggregation): compact the update sketch directly
+      wrapper.updateSketch.get.compact().toByteArrayCompressed
+    } else {
+      null
+    }
+  }
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    ThetaSketchType.serialize(wrapper)
+  }
+
+  override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    ThetaSketchType.deserialize(bytes)
+  }
+}
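
A minimal DataFrame-level usage sketch of this aggregate (assuming a local
SparkSession and the functions_ds helpers added further down in this patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions_ds._

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // build one sketch over the column, then read back the distinct-count estimate
    val df = (1 to 100).toDF("value")
    df.agg(theta_sketch_build("value", 12).as("sketch"))
      .select(theta_sketch_get_estimate("sketch").as("estimate"))
      .show()  // 100.0 (exact here, since 100 values fit under 2^12 nominal entries)
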
diff --git a/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala b/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala
new file mode 100644
index 0000000..29dd408
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.aggregate
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.{Sketch, SetOperation}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, ThetaSketchWrapper, ThetaSketchType}
+import org.apache.spark.SparkUnsupportedOperationException
+
+/**
+ * Theta Union operation.
+ *
+ * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
+ *
+ * @param child child expression, on which to perform the union operation
+ * @param lgk the size-accuracy trade-off parameter for the sketch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, lgk) - Performs Theta Union operation and returns the result 
as Theta Sketch in binary form
+  """,
+  examples = """
+    Example:
+      > SELECT theta_sketch_get_estimate(_FUNC_(sketch)) FROM (SELECT 
theta_sketch_build(col) as sketch FROM VALUES (1), (2), (3) tab(col) UNION ALL 
SELECT theta_sketch_build(col) as sketch FROM VALUES (3), (4), (5) tab(col));
+       5.0
+  """
+)
+// scalastyle:on line.size.limit
+case class ThetaUnion(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[ThetaSketchWrapper]
+    with BinaryLike[Expression]
+    with ExpectsInputTypes {
+
+  lazy val lgk: Int = {
+    right.eval() match {
+      case null => 12
+      case lgk: Int => lgk
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
+  def this(left: Expression, right: Expression) = this(left, right, 0, 0)
+  def this(left: Expression) = this(left, Literal(12), 0, 0)
+//  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
+//  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaUnion =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaUnion =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaUnion = {
+    copy(left = newLeft, right = newRight)
+  }
+
+  // overrides for TypedImperativeAggregate
+  override def prettyName: String = "theta_union"
+  override def dataType: DataType = ThetaSketchType
+  override def nullable: Boolean = false
+
+  // one entry per child: the sketch column, then the integer lgk parameter
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType, IntegerType)
+
+  override def createAggregationBuffer(): ThetaSketchWrapper =
+    new ThetaSketchWrapper(union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion))
+
+  override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
+    val bytes = left.eval(input)
+    if (bytes != null) {
+      left.dataType match {
+        case ThetaSketchType =>
+          wrapper.union.get.union(Sketch.wrap(Memory.wrap(bytes.asInstanceOf[Array[Byte]])))
+        case _ => throw new SparkUnsupportedOperationException(
+          s"Unsupported input type ${left.dataType.catalogString}",
+          Map("dataType" -> dataType.toString))
+      }
+    }
+    wrapper
+  }
+
+  override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
+    // a deserialized buffer carries its state as a compact sketch, if any
+    if (other != null && other.compactSketch.exists(!_.isEmpty)) {
+      if (wrapper.union.isEmpty) {
+        wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        if (wrapper.compactSketch.isDefined) {
+          wrapper.union.get.union(wrapper.compactSketch.get)
+          wrapper.compactSketch = None
+        }
+      }
+      wrapper.union.get.union(other.compactSketch.get)
+    }
+    wrapper
+  }
+
+  override def eval(wrapper: ThetaSketchWrapper): Any = {
+    wrapper.union.get.getResult.toByteArrayCompressed
+  }
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    ThetaSketchType.serialize(wrapper)
+  }
+
+  override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    ThetaSketchType.deserialize(bytes)
+  }
+}
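
The intended two-level pattern (per-group sketches, then one union across the
groups) looks roughly like this — a sketch assuming the same session and
imports as the example above:

    // per-group Theta sketches, merged with a Theta union
    val data = (1 to 2000).map(i => (i % 10, i)).toDF("group", "value")
    val perGroup = data.groupBy("group").agg(theta_sketch_build("value").as("sketch"))
    val merged = perGroup.agg(theta_union("sketch").as("merged"))
    merged.select(theta_sketch_get_estimate("merged")).show()  // 2000.0
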
diff --git a/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala b/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala
new file mode 100644
index 0000000..37709a7
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.Sketch
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, NullIntolerant, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, ThetaSketchType}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns distinct count estimate from a given sketch
+  """,
+  examples = """
+    Example:
+      > SELECT _FUNC_(theta_sketch_build(col)) FROM VALUES (1), (2), (3) tab(col);
+       3.0
+  """
+)
+case class ThetaSketchGetEstimate(child: Expression)
+  extends UnaryExpression
+    with ExpectsInputTypes
+    with NullIntolerant {
+
+  override protected def withNewChildInternal(newChild: Expression): ThetaSketchGetEstimate = {
+    copy(child = newChild)
+  }
+
+  override def prettyName: String = "theta_sketch_get_estimate"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType)
+
+  override def dataType: DataType = DoubleType
+
+  override def nullSafeEval(input: Any): Any = {
+    Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate
+  }
+
+  override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = {
+    val childEval = child.genCode(ctx)
+    // guard the generated code so a null child value is never dereferenced
+    val code = s"""
+      ${childEval.code}
+      double ${ev.value} = 0.0;
+      if (!${childEval.isNull}) {
+        ${ev.value} = org.apache.spark.sql.types.ThetaSketchWrapper
+          .wrapAsReadOnlySketch(${childEval.value}).getEstimate();
+      }
+    """
+    ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = childEval.isNull)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"($c)")
+  }
+}
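
For reference, what this expression computes for a single binary value,
written directly against the datasketches-java API (a standalone sketch,
assuming datasketches-java is on the classpath):

    import org.apache.datasketches.memory.Memory
    import org.apache.datasketches.theta.{Sketch, UpdateSketch}

    // serialize a small sketch, then wrap the bytes read-only and estimate
    val s = UpdateSketch.builder().build()
    (1 to 3).foreach(i => s.update(i.toLong))
    val bytes = s.compact().toByteArrayCompressed
    val estimate = Sketch.wrap(Memory.wrap(bytes)).getEstimate  // 3.0
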
diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala b/src/main/scala/org/apache/spark/sql/functions_ds.scala
index 7c2a96e..f30c452 100644
--- a/src/main/scala/org/apache/spark/sql/functions_ds.scala
+++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
+import org.apache.spark.sql.aggregate.{ThetaSketchBuild, ThetaUnion}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.expressions._
-import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.types.ArrayType
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types.{ArrayType, DoubleType}
 
 // this class defines and maps all the variants of each function invocation, analogous
 // to the functions object in org.apache.spark.sql.functions
@@ -152,4 +151,45 @@ object functions_ds {
     kll_get_cdf(Column(columnName), splitPoints)
   }
 
+  // Theta
+
+  def theta_sketch_build(column: Column, lgk: Int): Column = withAggregateFunction {
+    new ThetaSketchBuild(column.expr, lgk)
+  }
+
+  def theta_sketch_build(columnName: String, lgk: Int): Column = {
+    theta_sketch_build(Column(columnName), lgk)
+  }
+
+  def theta_sketch_build(column: Column): Column = withAggregateFunction {
+    new ThetaSketchBuild(column.expr)
+  }
+
+  def theta_sketch_build(columnName: String): Column = {
+    theta_sketch_build(Column(columnName))
+  }
+
+  def theta_union(column: Column, lgk: Int): Column = withAggregateFunction {
+    new ThetaUnion(column.expr, lit(lgk).expr)
+  }
+
+  def theta_union(columnName: String, lgk: Int): Column = withAggregateFunction {
+    new ThetaUnion(Column(columnName).expr, lit(lgk).expr)
+  }
+
+  def theta_union(column: Column): Column = withAggregateFunction {
+    new ThetaUnion(column.expr)
+  }
+
+  def theta_union(columnName: String): Column = withAggregateFunction {
+    new ThetaUnion(Column(columnName).expr)
+  }
+
+  def theta_sketch_get_estimate(column: Column): Column = withExpr {
+    new ThetaSketchGetEstimate(column.expr)
+  }
+
+  def theta_sketch_get_estimate(columnName: String): Column = {
+    theta_sketch_get_estimate(Column(columnName))
+  }
 }
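
All theta_sketch_build variants resolve to the same aggregate expression; a
short sketch of the overloads (assuming a DataFrame df with a numeric
"value" column):

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.functions_ds._

    // Column vs. column-name, with and without an explicit lgk
    val a = df.agg(theta_sketch_build(col("value"), 14))
    val b = df.agg(theta_sketch_build("value", 14))
    val c = df.agg(theta_sketch_build(col("value")))  // lgk defaults to 12
    val d = df.agg(theta_sketch_build("value"))
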
diff --git a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
index 5ab1738..58eacee 100644
--- a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
+++ b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
@@ -27,8 +27,10 @@ import scala.reflect.ClassTag
 
 // DataSketches imports
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
-import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax}
-import org.apache.spark.sql.expressions.KllGetPmfCdf
+import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax, KllGetPmfCdf}
+
+import org.apache.spark.sql.aggregate.{ThetaSketchBuild, ThetaUnion}
+import org.apache.spark.sql.expressions.ThetaSketchGetEstimate
 
 // based on org.apache.spark.sql.catalyst.FunctionRegistry
 trait DatasketchesFunctionRegistry {
@@ -80,6 +82,10 @@ object DatasketchesFunctionRegistry extends DatasketchesFunctionRegistry {
    complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] =>
      val isInclusive = if (args.length > 2) args(2).eval().asInstanceOf[Boolean] else true
      new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = false)
-    }
+    },
+
+    expression[ThetaSketchBuild]("theta_sketch_build"),
+    expression[ThetaUnion]("theta_union"),
+    expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate"),
   )
 }
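
Once registered, the three functions are callable from SQL; a minimal sketch
(assuming a SparkSession named spark):

    import org.apache.spark.sql.registrar.DatasketchesFunctionRegistry

    DatasketchesFunctionRegistry.registerFunctions(spark)
    spark.sql("""
      SELECT theta_sketch_get_estimate(theta_sketch_build(value, 12)) AS estimate
      FROM VALUES (1), (2), (3) tab(value)
    """).show()  // 3.0
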
diff --git a/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala b/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala
new file mode 100644
index 0000000..e5a5e2c
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+class ThetaSketchType extends UserDefinedType[ThetaSketchWrapper] {
+  override def sqlType: DataType = DataTypes.BinaryType
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    wrapper.serialize
+  }
+
+  override def deserialize(data: Any): ThetaSketchWrapper = {
+    val bytes = data.asInstanceOf[Array[Byte]]
+    ThetaSketchWrapper.deserialize(bytes)
+  }
+
+  override def userClass: Class[ThetaSketchWrapper] = classOf[ThetaSketchWrapper]
+
+  override def catalogString: String = "ThetaSketch"
+}
+
+case object ThetaSketchType extends ThetaSketchType
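
Since the UDT's sqlType is BinaryType, sketch columns are stored as plain
binary; a small illustration of the catalog mapping (for demonstration only,
assuming this patch is on the classpath):

    import org.apache.spark.sql.types.{StructField, StructType, ThetaSketchType}

    // a sketch column declared with the UDT; physically a binary column
    val schema = StructType(Seq(StructField("sketch", ThetaSketchType)))
    println(schema("sketch").dataType.catalogString)  // ThetaSketch
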
diff --git a/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala b/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala
new file mode 100644
index 0000000..86e3b89
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.datasketches.theta.{UpdateSketch, CompactSketch, Union}
+import org.apache.datasketches.memory.Memory
+
+@SQLUserDefinedType(udt = classOf[ThetaSketchType])
+class ThetaSketchWrapper(
+    var updateSketch: Option[UpdateSketch] = None,
+    var compactSketch: Option[CompactSketch] = None,
+    var union: Option[Union] = None) {
+
+  def serialize: Array[Byte] = {
+    if (updateSketch.isDefined) updateSketch.get.compact().toByteArrayCompressed
+    else if (compactSketch.isDefined) compactSketch.get.toByteArrayCompressed
+    else if (union.isDefined) union.get.getResult.toByteArrayCompressed
+    else null
+  }
+
+  override def toString(): String = {
+    if (updateSketch.isDefined) updateSketch.get.toString
+    else if (compactSketch.isDefined) compactSketch.get.toString
+    else if (union.isDefined) union.get.toString
+    else ""
+  }
+}
+
+object ThetaSketchWrapper {
+  def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    new ThetaSketchWrapper(compactSketch = Some(CompactSketch.heapify(Memory.wrap(bytes))))
+  }
+
+  // this can go away in favor of calling Sketch.wrap directly
+  // from codegen once janino can generate Java 8+ code
+  def wrapAsReadOnlySketch(bytes: Array[Byte]): CompactSketch = {
+    CompactSketch.wrap(Memory.wrap(bytes))
+  }
+}
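
The serialize/deserialize round-trip always lands on the compact form,
whichever of the three states the wrapper starts in; a sketch of that
behavior (assuming datasketches-java on the classpath):

    import org.apache.datasketches.theta.UpdateSketch
    import org.apache.spark.sql.types.ThetaSketchWrapper

    // start from an update sketch; after the round-trip only compactSketch is set
    val sk = UpdateSketch.builder().setLogNominalEntries(12).build()
    (1 to 10).foreach(i => sk.update(i.toLong))
    val wrapper = new ThetaSketchWrapper(updateSketch = Some(sk))
    val restored = ThetaSketchWrapper.deserialize(wrapper.serialize)
    assert(restored.compactSketch.isDefined && restored.updateSketch.isEmpty)
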
diff --git a/src/test/scala/org/apache/spark/sql/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
new file mode 100644
index 0000000..18fc021
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions_ds._
+import org.apache.spark.sql.registrar.DatasketchesFunctionRegistry
+
+class ThetaTest extends SparkSessionManager {
+  import spark.implicits._
+
+  test("Theta Sketch build via Scala") {
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+
+    val sketchDf = data.agg(theta_sketch_build("value").as("sketch"))
+    val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head
+
+    assert(result.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via SQL default lgk") {
+    DatasketchesFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_build(value)) AS estimate
+      FROM
+        theta_input_table
+    """)
+    assert(df.head.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via SQL with lgk") {
+    DatasketchesFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_build(value, 14)) AS estimate
+      FROM
+        theta_input_table
+    """)
+    assert(df.head.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Union via Scala") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+
+    val groupedDf = data.groupBy("group").agg(theta_sketch_build("value").as("sketch"))
+    val mergedDf = groupedDf.agg(theta_union("sketch").as("merged"))
+    val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head
+    assert(result.getAs[Double]("estimate") == numDistinct)
+  }
+
+/*
+  test("Theta Union via SQL default lgk") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val groupedDf = spark.sql(s"""
+      SELECT
+        group,
+        theta_sketch_build(value) AS sketch
+      FROM
+        theta_input_table
+      GROUP BY
+        group
+    """)
+    groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+    val mergedDf = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_union(sketch)) AS estimate
+      FROM
+        theta_sketch_table
+    """)
+    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+  }
+
+  test("Theta Union via SQL with lgk") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val groupedDf = spark.sql(s"""
+      SELECT
+        group,
+        theta_sketch_build(value, 14) AS sketch
+      FROM
+        theta_input_table
+      GROUP BY
+        group
+    """)
+    groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+    val mergedDf = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_union(sketch, 14)) AS estimate
+      FROM
+        theta_sketch_table
+    """)
+    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+  }
+*/
+}

