This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch theta_params
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit e3f9503105e2ff32fc140089a3b57541e17feb5c
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Fri Feb 21 22:51:13 2025 -0800

    added params to theta build agg
---
 .../theta/aggregate/ThetaSketchAggBuild.scala      | 79 ++++++++++++++++------
 .../spark/sql/datasketches/theta/functions.scala   | 16 +++++
 .../spark/sql/datasketches/theta/ThetaTest.scala   | 68 ++++++++++++++++++-
 3 files changed, 140 insertions(+), 23 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
index 95dbd61..70d13e2 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
@@ -23,9 +23,16 @@ import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSket
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ExpressionDescription, Literal}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.trees.QuaternaryLike
 import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
LongType, NumericType, FloatType, DoubleType}
 
+import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
+import org.apache.datasketches.common.ResizeFactor
+
+object ThetaSketchConstants {
+  final val DEFAULT_LG_K: Int = 12
+}
+
 /**
  * The ThetaSketchBuild function creates a Theta sketch from a column of values
  * which can be used to estimate distinct count.
@@ -48,26 +55,55 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, Long
 )
 // scalastyle:on line.size.limit
 case class ThetaSketchAggBuild(
-    left: Expression,
-    right: Expression,
+    inputExpr: Expression,
+    lgKExpr: Expression,
+    seedExpr: Expression,
+    pExpr: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[ThetaSketchWrapper]
-    with BinaryLike[Expression]
+    with QuaternaryLike[Expression]
     with ExpectsInputTypes {
 
-  lazy val lgk: Int = {
-    right.eval() match {
-      case null => 12
-      case lgk: Int => lgk
+  lazy val lgK: Int = {
+    lgKExpr.eval() match {
+      case null => ThetaSketchConstants.DEFAULT_LG_K
+      case lgK: Int => lgK
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported input type ${lgKExpr.dataType.catalogString}")
+    }
+  }
+
+  lazy val seed: Long = {
+    seedExpr.eval() match {
+      case null => DEFAULT_UPDATE_SEED
+      case seed: Long => seed
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported input type ${seedExpr.dataType.catalogString}")
+    }
+  }
+
+  lazy val p: Float = {
+    pExpr.eval() match {
+      case null => 1f
+      case p: Float => p
       case _ => throw new IllegalArgumentException(
-        s"Unsupported input type ${right.dataType.catalogString}")
+        s"Unsupported input type ${pExpr.dataType.catalogString}")
     }
   }
+  override def first: Expression = inputExpr
+  override def second: Expression = lgKExpr
+  override def third: Expression = seedExpr
+  override def fourth: Expression = pExpr
+
+  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, 
pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0)
+  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = 
this(inputExpr, lgKExpr, seedExpr, Literal(1f))
+  def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, 
lgKExpr, Literal(DEFAULT_UPDATE_SEED))
+  def this(inputExpr: Expression) = this(inputExpr, 
Literal(ThetaSketchConstants.DEFAULT_LG_K))
 
-  def this(child: Expression) = this(child, Literal(12), 0, 0)
-  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
-  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+  def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
+  def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, 
Literal(lgK), Literal(seed))
+  def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float) = 
this(inputExpr, Literal(lgK), Literal(seed), Literal(p))
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ThetaSketchAggBuild =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -75,8 +111,8 @@ case class ThetaSketchAggBuild(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ThetaSketchAggBuild =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override protected def withNewChildrenInternal(newLeft: Expression, 
newRight: Expression): ThetaSketchAggBuild = {
-    copy(left = newLeft, right = newRight)
+  override protected def withNewChildrenInternal(newFirst: Expression, 
newSecond: Expression, newThird: Expression, newFourth: Expression): 
ThetaSketchAggBuild = {
+    copy(inputExpr = newFirst, lgKExpr = newSecond, seedExpr = newThird, pExpr 
= newFourth)
   }
 
   override def prettyName: String = "theta_sketch_build"
@@ -87,27 +123,28 @@ case class ThetaSketchAggBuild(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType, LongType, FloatType, DoubleType)
 
-  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(updateSketch = 
Some(UpdateSketch.builder().setLogNominalEntries(lgk).build()))
+  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(updateSketch
+    = 
Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build()))
 
   override def update(wrapper: ThetaSketchWrapper, input: InternalRow): 
ThetaSketchWrapper = {
-    val value = left.eval(input)
+    val value = inputExpr.eval(input)
     if (value != null) {
-      left.dataType match {
+      inputExpr.dataType match {
         case DoubleType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Double])
         case FloatType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Float])
         case IntegerType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Int])
         case LongType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Long])
         case _ => throw new IllegalArgumentException(
-          s"Unsupported input type ${left.dataType.catalogString}")
+          s"Unsupported input type ${inputExpr.dataType.catalogString}")
       }
     }
     wrapper
   }
 
   override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): 
ThetaSketchWrapper = {
-    if (other != null && !other.compactSketch.get.isEmpty) {
+    if (other != null && !other.compactSketch.get.isEmpty()) {
       if (wrapper.union.isEmpty) {
-        wrapper.union = 
Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        wrapper.union = 
Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).buildUnion())
         if (wrapper.compactSketch.isDefined) {
           wrapper.union.get.union(wrapper.compactSketch.get)
           wrapper.compactSketch = None
@@ -122,7 +159,7 @@ case class ThetaSketchAggBuild(
     if (wrapper == null || wrapper.union.isEmpty) {
       null
     } else {
-      wrapper.union.get.getResult.toByteArrayCompressed
+      wrapper.union.get.getResult.toByteArrayCompressed()
     }
   }
 
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
index 2a788a7..240aeea 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
@@ -26,6 +26,22 @@ import org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimat
 import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
 
 object functions extends DatasketchesScalaFunctionBase {
+  def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float): 
Column = withAggregateFunction {
+    new ThetaSketchAggBuild(column.expr, lgk, seed, p)
+  }
+
+  def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long, p: 
Float): Column = {
+    theta_sketch_agg_build(Column(columnName), lgk, seed, p)
+  }
+
+  def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long): Column = 
withAggregateFunction {
+    new ThetaSketchAggBuild(column.expr, lgk, seed)
+  }
+
+  def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long): Column 
= {
+    theta_sketch_agg_build(Column(columnName), lgk, seed)
+  }
+
   def theta_sketch_agg_build(column: Column, lgk: Int): Column = 
withAggregateFunction {
     new ThetaSketchAggBuild(column.expr, lgk)
   }
diff --git a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
index 53bd0aa..62a547f 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
@@ -22,10 +22,12 @@ import org.apache.spark.sql.datasketches.common.SparkSessionManager
 import org.apache.spark.sql.datasketches.theta.functions._
 import org.apache.spark.sql.datasketches.theta.ThetaFunctionRegistry
 
+import org.scalatest.matchers.should.Matchers._
+
 class ThetaTest extends SparkSessionManager {
   import spark.implicits._
 
-  test("Theta Sketch build via Scala") {
+  test("Theta Sketch build via Scala with defaults") {
     val n = 100
     val data = (for (i <- 1 to n) yield i).toDF("value")
 
@@ -35,7 +37,37 @@ class ThetaTest extends SparkSessionManager {
     assert(result.getAs[Double]("estimate") == 100.0)
   }
 
-  test("Theta Sketch build via SQL default lgk") {
+  test("Theta Sketch build via Scala with lgk") {
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+
+    val sketchDf = data.agg(theta_sketch_agg_build("value", 14).as("sketch"))
+    val result: Row = 
sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
+
+    assert(result.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via Scala with lgk and seed") {
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+
+    val sketchDf = data.agg(theta_sketch_agg_build("value", 14, 
111).as("sketch"))
+    val result: Row = 
sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
+
+    assert(result.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via Scala with lgk, seed and p") {
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+
+    val sketchDf = data.agg(theta_sketch_agg_build("value", 14, 111, 
0.99f).as("sketch"))
+    val result: Row = 
sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
+
+    result.getAs[Double]("estimate") shouldBe (100.0 +- 2.0)
+  }
+
+  test("Theta Sketch build via SQL with defaults") {
     ThetaFunctionRegistry.registerFunctions(spark)
 
     val n = 100
@@ -67,6 +99,38 @@ class ThetaTest extends SparkSessionManager {
     assert(df.head().getAs[Double]("estimate") == 100.0)
   }
 
+  test("Theta Sketch build via SQL with lgk and seed") {
+    ThetaFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L)) AS 
estimate
+      FROM
+        theta_input_table
+    """)
+    assert(df.head().getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via SQL with lgk, seed and p") {
+    ThetaFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L, 
0.99f)) AS estimate
+      FROM
+        theta_input_table
+    """)
+    df.head().getAs[Double]("estimate") shouldBe (100.0 +- 2.0)
+  }
+
   test("Theta Union via Scala") {
     val numGroups = 10
     val numDistinct = 2000


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to