This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch theta_params
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git


The following commit(s) were added to refs/heads/theta_params by this push:
     new ffe0066  theta functions with parameters
ffe0066 is described below

commit ffe00665819ea1a92cd4f0ddd11cacc9672c9fce
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Mon Feb 24 20:30:28 2025 -0800

    theta functions with parameters
---
 .../datasketches/theta/ThetaFunctionRegistry.scala |  5 +-
 ...onRegistry.scala => ThetaSketchConstants.scala} | 18 +----
 .../theta/aggregate/ThetaSketchAggBuild.scala      | 39 +++++-----
 .../theta/aggregate/ThetaSketchAggUnion.scala      | 88 ++++++++++++++--------
 .../theta/expressions/ThetaExpressions.scala       | 73 ++++++++++++++++--
 .../spark/sql/datasketches/theta/functions.scala   | 28 +++++--
 .../spark/sql/datasketches/theta/ThetaTest.scala   | 65 ++++++++++++++--
 7 files changed, 230 insertions(+), 86 deletions(-)

diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
index ddbe615..3ceeb3c 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
@@ -22,13 +22,14 @@ import 
org.apache.spark.sql.catalyst.expressions.{ExpressionInfo}
 
 import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
 import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, 
ThetaSketchAggUnion}
-import 
org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
+import 
org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstimate, 
ThetaSketchToString}
 import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
 
 object ThetaFunctionRegistry extends DatasketchesFunctionRegistry {
   override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = 
Map(
     expression[ThetaSketchAggBuild]("theta_sketch_agg_build"),
     expression[ThetaSketchAggUnion]("theta_sketch_agg_union"),
-    expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate")
+    expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate"),
+    expression[ThetaSketchToString]("theta_sketch_to_string") // SQL name; same string as ThetaSketchToString.prettyName
   )
 }
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
similarity index 50%
copy from 
src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
copy to 
src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
index ddbe615..1e8befc 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
@@ -17,18 +17,6 @@
 
 package org.apache.spark.sql.datasketches.theta
 
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo}
-
-import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
-import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, 
ThetaSketchAggUnion}
-import 
org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
-import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
-
-object ThetaFunctionRegistry extends DatasketchesFunctionRegistry {
-  override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = 
Map(
-    expression[ThetaSketchAggBuild]("theta_sketch_agg_build"),
-    expression[ThetaSketchAggUnion]("theta_sketch_agg_union"),
-    expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate")
-  )
-}
+object ThetaSketchConstants { // defaults shared by the Theta sketch build and union aggregates
+  final val DEFAULT_LG_K: Int = 12 // default lgK, i.e. log2 of nominal entries passed to setLogNominalEntries (K = 4096)
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
index 70d13e2..f6eaba5 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
@@ -17,21 +17,19 @@
 
 package org.apache.spark.sql.datasketches.theta.aggregate
 
-import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
-import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, 
ThetaSketchWrapper}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ExpressionDescription, Literal}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.catalyst.trees.QuaternaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
LongType, NumericType, FloatType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
LongType, NumericType, FloatType, DoubleType, StringType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
 
-import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
-import org.apache.datasketches.common.ResizeFactor
+import 
org.apache.spark.sql.datasketches.theta.ThetaSketchConstants.DEFAULT_LG_K
+import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, 
ThetaSketchWrapper}
 
-object ThetaSketchConstants {
-  final val DEFAULT_LG_K: Int = 12
-}
+import org.apache.datasketches.common.ResizeFactor
+import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
+import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
 
 /**
  * The ThetaSketchBuild function creates a Theta sketch from a column of values
@@ -39,17 +37,21 @@ object ThetaSketchConstants {
  *
  * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for 
more information.
  *
- * @param child child expression, from which to build a sketch
- * @param lgk the size-accraucy trade-off parameter for the sketch
+ * @param inputExpr expression from which to build a sketch
+ * @param lgKExpr size-accuracy trade-off parameter for the sketch (log2 of nominal entries)
+ * @param seedExpr update seed for the sketch
+ * @param pExpr initial sampling probability for the sketch
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, lgk) - Creates a Theta Sketch and returns the binary 
representation.
-      `lgk` (optional, default: 12) the size-accuracy trade-off parameter.""",
+    _FUNC_(expr, lgK, seed, p) - Creates a Theta Sketch and returns the binary 
representation.
+      `lgK` (optional, default: 12) size-accuracy trade-off parameter.
+      `seed` (optional, default: 9001) update seed for the sketch.
+      `p` (optional, default: 1) initial sampling probability for the 
sketch.""",
   examples = """
     Example:
-      > SELECT theta_sketch_get_estimate(_FUNC_(col, 12)) FROM VALUES (1), 
(2), (3), (4), (5) tab(col);
+      > SELECT theta_sketch_get_estimate(_FUNC_(col)) FROM VALUES (1), (2), 
(3), (4), (5) tab(col);
        5.0
   """,
 )
@@ -67,7 +69,7 @@ case class ThetaSketchAggBuild(
 
   lazy val lgK: Int = {
     lgKExpr.eval() match {
-      case null => ThetaSketchConstants.DEFAULT_LG_K
+      case null => DEFAULT_LG_K
       case lgK: Int => lgK
       case _ => throw new IllegalArgumentException(
         s"Unsupported input type ${lgKExpr.dataType.catalogString}")
@@ -99,7 +101,7 @@ case class ThetaSketchAggBuild(
   def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, 
pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0)
   def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = 
this(inputExpr, lgKExpr, seedExpr, Literal(1f))
   def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, 
lgKExpr, Literal(DEFAULT_UPDATE_SEED))
-  def this(inputExpr: Expression) = this(inputExpr, 
Literal(ThetaSketchConstants.DEFAULT_LG_K))
+  def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
 
   def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
   def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, 
Literal(lgK), Literal(seed))
@@ -115,13 +117,13 @@ case class ThetaSketchAggBuild(
     copy(inputExpr = newFirst, lgKExpr = newSecond, seedExpr = newThird, pExpr 
= newFourth)
   }
 
-  override def prettyName: String = "theta_sketch_build"
+  override def prettyName: String = "theta_sketch_agg_build"
 
   override def dataType: DataType = ThetaSketchType
 
   override def nullable: Boolean = false
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType, LongType, FloatType, DoubleType)
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, StringType), IntegerType, LongType, FloatType)
 
   override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(updateSketch
     = 
Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build()))
@@ -134,6 +136,7 @@ case class ThetaSketchAggBuild(
         case FloatType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Float])
         case IntegerType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Int])
         case LongType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Long])
+        case StringType => 
wrapper.updateSketch.get.update(value.asInstanceOf[UTF8String].toString)
         case _ => throw new IllegalArgumentException(
           s"Unsupported input type ${inputExpr.dataType.catalogString}")
       }
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
index 6e067c9..fd8ce17 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
@@ -17,28 +17,34 @@
 
 package org.apache.spark.sql.datasketches.theta.aggregate
 
-import org.apache.datasketches.memory.Memory
-import org.apache.datasketches.theta.{Sketch, SetOperation}
-import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, 
ThetaSketchWrapper}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ExpressionDescription, Literal}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
+import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
LongType}
+
+import 
org.apache.spark.sql.datasketches.theta.ThetaSketchConstants.DEFAULT_LG_K
+import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, 
ThetaSketchWrapper}
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.{Sketch, SetOperation}
+import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
 
 /**
  * Theta Union operation.
  *
  * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for 
more information.
  *
- * @param child child expression, on which to perform the union operation
- * @param lgk the size-accraucy trade-off parameter for the sketch
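+ * @param inputExpr expression producing the sketches on which to perform the union operation
+ * @param lgKExpr size-accuracy trade-off parameter for the union (log2 of nominal entries)
+ * @param seedExpr update seed for the union; the input sketches are expected to use the same seed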
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, lgk) - Performs Theta Union operation and returns the result 
as Theta Sketch in binary form
+    _FUNC_(expr, lgK, seed) - Performs Theta Union operation and returns the 
result as Theta Sketch in binary form
+      `lgK` (optional, default: 12) size-accuracy trade-off parameter.
+      `seed` (optional, default: 9001) update seed for the sketch.
   """,
   examples = """
     Example:
@@ -48,27 +54,43 @@ import org.apache.spark.sql.types.{AbstractDataType, 
DataType, IntegerType}
 )
 // scalastyle:on line.size.limit
 case class ThetaSketchAggUnion(
-    left: Expression,
-    right: Expression,
+    inputExpr: Expression,
+    lgKExpr: Expression,
+    seedExpr: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[ThetaSketchWrapper]
-    with BinaryLike[Expression]
+    with TernaryLike[Expression]
     with ExpectsInputTypes {
 
-  lazy val lgk: Int = {
-    right.eval() match {
-      case null => 12
-      case lgk: Int => lgk
+  lazy val lgK: Int = {
+    lgKExpr.eval() match {
+      case null => DEFAULT_LG_K
+      case lgK: Int => lgK
       case _ => throw new IllegalArgumentException(
-        s"Unsupported input type ${right.dataType.catalogString}")
+        s"Unsupported input type ${lgKExpr.dataType.catalogString}")
     }
   }
 
-  def this(left: Expression, right: Expression) = this(left, right, 0, 0)
-  def this(left: Expression) = this(left, Literal(12), 0, 0)
-//  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
-//  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+  lazy val seed: Long = {
+    seedExpr.eval() match {
+      case null => DEFAULT_UPDATE_SEED
+      case seed: Long => seed
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported input type ${seedExpr.dataType.catalogString}")
+    }
+  }
+
+  override def first: Expression = inputExpr
+  override def second: Expression = lgKExpr
+  override def third: Expression = seedExpr
+
+  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = 
this(inputExpr, lgKExpr, seedExpr, 0, 0)
+  def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, 
lgKExpr, Literal(DEFAULT_UPDATE_SEED))
+  def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
+
+  def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
+  def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, 
Literal(lgK), Literal(seed))
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ThetaSketchAggUnion =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -76,37 +98,37 @@ case class ThetaSketchAggUnion(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ThetaSketchAggUnion =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override protected def withNewChildrenInternal(newLeft: Expression, 
newRight: Expression): ThetaSketchAggUnion = {
-    copy(left = newLeft, right = newRight)
+  override protected def withNewChildrenInternal(newInputExpr: Expression, 
newLgKExpr: Expression, newSeedExpr: Expression): ThetaSketchAggUnion = {
+    copy(inputExpr = newInputExpr, lgKExpr = newLgKExpr, seedExpr = 
newSeedExpr)
   }
 
   // overrides for TypedImperativeAggregate
-  override def prettyName: String = "theta_union"
+  override def prettyName: String = "theta_sketch_agg_union"
   override def dataType: DataType = ThetaSketchType
   override def nullable: Boolean = false
 
-  // TODO: refine this?
-  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType, 
IntegerType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType, 
IntegerType, LongType)
 
-  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(union = 
Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion))
+  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(union
+    = 
Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).buildUnion()))
 
   override def update(wrapper: ThetaSketchWrapper, input: InternalRow): 
ThetaSketchWrapper = {
-    val bytes = left.eval(input)
+    val bytes = inputExpr.eval(input)
     if (bytes != null) {
-      left.dataType match {
+      inputExpr.dataType match {
         case ThetaSketchType =>
           
wrapper.union.get.union(Sketch.wrap(Memory.wrap(bytes.asInstanceOf[Array[Byte]])))
         case _ => throw new IllegalArgumentException(
-          s"Unsupported input type ${left.dataType.catalogString}")
+          s"Unsupported input type ${inputExpr.dataType.catalogString}")
       }
     }
     wrapper
   }
 
   override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): 
ThetaSketchWrapper = {
-    if (other != null && !other.compactSketch.get.isEmpty) {
+    if (other != null && !other.compactSketch.get.isEmpty()) {
       if (wrapper.union.isEmpty) {
-        wrapper.union = 
Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        wrapper.union = 
Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).buildUnion())
         if (wrapper.compactSketch.isDefined) {
           wrapper.union.get.union(wrapper.compactSketch.get)
           wrapper.compactSketch = None
@@ -118,7 +140,7 @@ case class ThetaSketchAggUnion(
   }
 
   override def eval(wrapper: ThetaSketchWrapper): Any = {
-    wrapper.union.get.getResult.toByteArrayCompressed
+    wrapper.union.get.getResult.toByteArrayCompressed()
   }
 
   override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
index c426ee4..c77fcd4 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
@@ -17,15 +17,17 @@
 
 package org.apache.spark.sql.datasketches.theta.expressions
 
-import org.apache.datasketches.memory.Memory
-import org.apache.datasketches.theta.Sketch
-import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
-
 import org.apache.spark.sql.catalyst.expressions.{Expression, 
ExpectsInputTypes, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.NullIntolerant
 import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, 
CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, 
StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.Sketch
 
 @ExpressionDescription(
   usage = """
@@ -33,7 +35,7 @@ import org.apache.spark.sql.types.{AbstractDataType, 
DataType, DoubleType}
   """,
   examples = """
     Example:
-      > SELECT _FUNC_(theta_sketch_build(col)) FROM VALUES (1), (2), (3) 
tab(col);
+      > SELECT _FUNC_(theta_sketch_agg_build(col)) FROM VALUES (1), (2), (3) 
tab(col);
        3.0
   """
 )
@@ -53,7 +55,7 @@ case class ThetaSketchGetEstimate(child: Expression)
   override def dataType: DataType = DoubleType
 
   override def nullSafeEval(input: Any): Any = {
-    Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate
+    Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate()
   }
 
   override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: 
String => String): ExprCode = {
@@ -71,3 +73,60 @@ case class ThetaSketchGetEstimate(child: Expression)
     nullSafeCodeGen(ctx, ev, c => s"($c)")
   }
 }
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns a summary string that represents the state of the 
given sketch
+  """,
+  examples = """
+    Example:
+      > SELECT _FUNC_(theta_sketch_agg_build(col)) FROM VALUES (1), (2), (3) 
tab(col);
+      ### HeapCompactSketch SUMMARY: 
+         Estimate                : 3.0
+         Upper Bound, 95% conf   : 3.0
+         Lower Bound, 95% conf   : 3.0
+         Theta (double)          : 1.0
+         Theta (long)            : 9223372036854775807
+         Theta (long) hex        : 7fffffffffffffff
+         EstMode?                : false
+         Empty?                  : false
+         Ordered?                : true
+         Retained Entries        : 3
+         Seed Hash               : 93cc | 37836
+      ### END SKETCH SUMMARY
+  """
+)
+case class ThetaSketchToString(child: Expression)
+ extends UnaryExpression
+ with ExpectsInputTypes
+ with NullIntolerant {
+
+  // Renders a serialized Theta sketch (ThetaSketchType) as the human-readable
+  // summary produced by DataSketches Sketch.toString(). A null input yields
+  // null: UnaryExpression only calls nullSafeEval for non-null input, and
+  // defineCodeGen below emits the equivalent null check in generated code.
+
+  override protected def withNewChildInternal(newChild: Expression): ThetaSketchToString = {
+    copy(child = newChild)
+  }
+
+  override def prettyName: String = "theta_sketch_to_string"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType)
+
+  override def dataType: DataType = StringType
+
+  override def nullSafeEval(input: Any): Any = {
+    UTF8String.fromString(Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).toString())
+  }
+
+  // Generated Java mirrors nullSafeEval. defineCodeGen assigns this expression to
+  // ev.value, which the inherited nullSafeCodeGen declares with the Java type of
+  // StringType (UTF8String), so no hand-written declaration is needed here.
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, c =>
+      "org.apache.spark.unsafe.types.UTF8String.fromString(" +
+        s"org.apache.datasketches.theta.Sketch.wrap(org.apache.datasketches.memory.Memory.wrap($c))" +
+        ".toString())")
+  }
+}
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
index 240aeea..7445d50 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.functions.lit
 
 import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
 import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, 
ThetaSketchAggUnion}
-import 
org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
+import 
org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstimate, 
ThetaSketchToString}
 import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
 
 object functions extends DatasketchesScalaFunctionBase {
@@ -58,20 +58,28 @@ object functions extends DatasketchesScalaFunctionBase {
     theta_sketch_agg_build(Column(columnName))
   }
 
+  def theta_sketch_agg_union(column: Column, lgk: Int, seed: Long): Column = 
withAggregateFunction {
+    new ThetaSketchAggUnion(column.expr, lgk, seed)
+  }
+
+  def theta_sketch_agg_union(columnName: String, lgk: Int, seed: Long): Column 
= {
+    theta_sketch_agg_union(Column(columnName), lgk, seed)
+  }
+
   def theta_sketch_agg_union(column: Column, lgk: Int): Column = 
withAggregateFunction {
-    new ThetaSketchAggUnion(column.expr, lit(lgk).expr)
+    new ThetaSketchAggUnion(column.expr, lgk)
   }
 
-  def theta_sketch_agg_union(columnName: String, lgk: Int): Column = 
withAggregateFunction {
-    new ThetaSketchAggUnion(Column(columnName).expr, lit(lgk).expr)
+  def theta_sketch_agg_union(columnName: String, lgk: Int): Column = {
+    theta_sketch_agg_union(Column(columnName), lgk)
   }
 
   def theta_sketch_agg_union(column: Column): Column = withAggregateFunction {
     new ThetaSketchAggUnion(column.expr)
   }
 
-  def theta_sketch_agg_union(columnName: String): Column = 
withAggregateFunction {
-    new ThetaSketchAggUnion(Column(columnName).expr)
+  def theta_sketch_agg_union(columnName: String): Column = {
+    theta_sketch_agg_union(Column(columnName))
   }
 
   def theta_sketch_get_estimate(column: Column): Column = withExpr {
@@ -81,4 +89,12 @@ object functions extends DatasketchesScalaFunctionBase {
   def theta_sketch_get_estimate(columnName: String): Column = {
     theta_sketch_get_estimate(Column(columnName))
   }
+
+  def theta_sketch_to_string(column: Column): Column = withExpr {
+    new ThetaSketchToString(column.expr)
+  }
+
+  def theta_sketch_to_string(columnName: String): Column = {
+    theta_sketch_to_string(Column(columnName))
+  }
 }
diff --git 
a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala 
b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
index 62a547f..09add01 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
@@ -99,16 +99,16 @@ class ThetaTest extends SparkSessionManager {
     assert(df.head().getAs[Double]("estimate") == 100.0)
   }
 
-  test("Theta Sketch build via SQL with lgk and seed") {
+  test("Theta Sketch build from strings via SQL with lgk and seed") {
     ThetaFunctionRegistry.registerFunctions(spark)
 
     val n = 100
-    val data = (for (i <- 1 to n) yield i).toDF("value")
+    val data = (for (i <- 1 to n) yield i.toString()).toDF("str")
     data.createOrReplaceTempView("theta_input_table")
 
     val df = spark.sql(s"""
       SELECT
-        theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L)) AS 
estimate
+        theta_sketch_get_estimate(theta_sketch_agg_build(str, 14, 111L)) AS 
estimate
       FROM
         theta_input_table
     """)
@@ -131,7 +131,7 @@ class ThetaTest extends SparkSessionManager {
     df.head().getAs[Double]("estimate") shouldBe (100.0 +- 2.0)
   }
 
-  test("Theta Union via Scala") {
+  test("Theta Union via Scala with defauls") {
     val numGroups = 10
     val numDistinct = 2000
     val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
@@ -142,8 +142,34 @@ class ThetaTest extends SparkSessionManager {
     assert(result.getAs[Double]("estimate") == numDistinct)
   }
 
+  test("Theta Union via Scala with lgk") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
+
+    val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value", 
14).as("sketch"))
+    val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch", 
14).as("merged"))
+    val result: Row = 
mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
+    assert(result.getAs[Double]("estimate") == numDistinct)
+  }
+
+  test("Theta Union via Scala with lgk and seed") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
+
+    val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value", 
14, 111).as("sketch"))
+    val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch", 14, 
111).as("merged"))
+    val result: Row = 
mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
+    assert(result.getAs[Double]("estimate") == numDistinct)
+
+    val toStr: Row = 
mergedDf.select(theta_sketch_to_string("merged").as("summary")).head()
+    toStr.getAs[String]("summary") should startWith ("\n### HeapCompactSketch")
+  }
+
+  test("Theta Union via SQL with defaults") {
+    ThetaFunctionRegistry.registerFunctions(spark)
 
-  test("Theta Union via SQL default lgk") {
     val numGroups = 10
     val numDistinct = 2000
     val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
@@ -170,6 +196,8 @@ class ThetaTest extends SparkSessionManager {
   }
 
   test("Theta Union via SQL with lgk") {
+    ThetaFunctionRegistry.registerFunctions(spark)
+
     val numGroups = 10
     val numDistinct = 2000
     val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
@@ -195,4 +223,31 @@ class ThetaTest extends SparkSessionManager {
     assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
   }
 
+  test("Theta Union via SQL with lgk and seed") {
+    ThetaFunctionRegistry.registerFunctions(spark)
+
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, 
i)).toDF("group", "value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val groupedDf = spark.sql(s"""
+      SELECT
+        group,
+        theta_sketch_agg_build(value, 14, 111L) AS sketch
+      FROM
+        theta_input_table
+      GROUP BY
+        group
+    """)
+    groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+    val mergedDf = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_agg_union(sketch, 14, 111L)) AS 
estimate
+      FROM
+        theta_sketch_table
+    """)
+    assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to