This is an automated email from the ASF dual-hosted git repository. alsay pushed a commit to branch theta_params in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit e3f9503105e2ff32fc140089a3b57541e17feb5c Author: AlexanderSaydakov <[email protected]> AuthorDate: Fri Feb 21 22:51:13 2025 -0800 added params to theta build agg --- .../theta/aggregate/ThetaSketchAggBuild.scala | 79 ++++++++++++++++------ .../spark/sql/datasketches/theta/functions.scala | 16 +++++ .../spark/sql/datasketches/theta/ThetaTest.scala | 68 ++++++++++++++++++- 3 files changed, 140 insertions(+), 23 deletions(-) diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala index 95dbd61..70d13e2 100644 --- a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala +++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala @@ -23,9 +23,16 @@ import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSket import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate -import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.trees.QuaternaryLike import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType} +import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED +import org.apache.datasketches.common.ResizeFactor + +object ThetaSketchConstants { + final val DEFAULT_LG_K: Int = 12 +} + /** * The ThetaSketchBuild function creates a Theta sketch from a column of values * which can be used to estimate distinct count. @@ -48,26 +55,55 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, Long ) // scalastyle:on line.size.limit case class ThetaSketchAggBuild( - left: Expression, - right: Expression, + inputExpr: Expression, + lgKExpr: Expression, + seedExpr: Expression, + pExpr: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[ThetaSketchWrapper] - with BinaryLike[Expression] + with QuaternaryLike[Expression] with ExpectsInputTypes { - lazy val lgk: Int = { - right.eval() match { - case null => 12 - case lgk: Int => lgk + lazy val lgK: Int = { + lgKExpr.eval() match { + case null => ThetaSketchConstants.DEFAULT_LG_K + case lgK: Int => lgK + case _ => throw new IllegalArgumentException( + s"Unsupported input type ${lgKExpr.dataType.catalogString}") + } + } + + lazy val seed: Long = { + seedExpr.eval() match { + case null => DEFAULT_UPDATE_SEED + case seed: Long => seed + case _ => throw new IllegalArgumentException( + s"Unsupported input type ${seedExpr.dataType.catalogString}") + } + } + + lazy val p: Float = { + pExpr.eval() match { + case null => 1f + case p: Float => p case _ => throw new IllegalArgumentException( - s"Unsupported input type ${right.dataType.catalogString}") + s"Unsupported input type ${pExpr.dataType.catalogString}") } } + override def first: Expression = inputExpr + override def second: Expression = lgKExpr + override def third: Expression = seedExpr + override def fourth: Expression = pExpr + + def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0) + def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, Literal(1f)) + def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, lgKExpr, Literal(DEFAULT_UPDATE_SEED)) + def this(inputExpr: Expression) = this(inputExpr, Literal(ThetaSketchConstants.DEFAULT_LG_K)) - def this(child: Expression) = this(child, Literal(12), 0, 0) - def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0) - def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0) + def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK)) + def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, Literal(lgK), Literal(seed)) + def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float) = this(inputExpr, Literal(lgK), Literal(seed), Literal(p)) override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAggBuild = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -75,8 +111,8 @@ case class ThetaSketchAggBuild( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAggBuild = copy(inputAggBufferOffset = newInputAggBufferOffset) - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaSketchAggBuild = { - copy(left = newLeft, right = newRight) + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression, newFourth: Expression): ThetaSketchAggBuild = { + copy(inputExpr = newFirst, lgKExpr = newSecond, seedExpr = newThird, pExpr = newFourth) } override def prettyName: String = "theta_sketch_build" @@ -87,27 +123,28 @@ case class ThetaSketchAggBuild( override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType, LongType, FloatType, DoubleType) - override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(updateSketch = Some(UpdateSketch.builder().setLogNominalEntries(lgk).build())) + override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(updateSketch + = Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build())) override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = { - val value = left.eval(input) + val value = inputExpr.eval(input) if (value != null) { - left.dataType match { + inputExpr.dataType match { case DoubleType => wrapper.updateSketch.get.update(value.asInstanceOf[Double]) case FloatType => wrapper.updateSketch.get.update(value.asInstanceOf[Float]) case IntegerType => wrapper.updateSketch.get.update(value.asInstanceOf[Int]) case LongType => wrapper.updateSketch.get.update(value.asInstanceOf[Long]) case _ => throw new IllegalArgumentException( - s"Unsupported input type ${left.dataType.catalogString}") + s"Unsupported input type ${inputExpr.dataType.catalogString}") } } wrapper } override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = { - if (other != null && !other.compactSketch.get.isEmpty) { + if (other != null && !other.compactSketch.get.isEmpty()) { if (wrapper.union.isEmpty) { - wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion) + wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).buildUnion()) if (wrapper.compactSketch.isDefined) { wrapper.union.get.union(wrapper.compactSketch.get) wrapper.compactSketch = None @@ -122,7 +159,7 @@ case class ThetaSketchAggBuild( if (wrapper == null || wrapper.union.isEmpty) { null } else { - wrapper.union.get.getResult.toByteArrayCompressed + wrapper.union.get.getResult.toByteArrayCompressed() } } diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala index 2a788a7..240aeea 100644 --- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala +++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala @@ -26,6 +26,22 @@ import org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimat import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase object functions extends DatasketchesScalaFunctionBase { + def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float): Column = withAggregateFunction { + new ThetaSketchAggBuild(column.expr, lgk, seed, p) + } + + def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long, p: Float): Column = { + theta_sketch_agg_build(Column(columnName), lgk, seed, p) + } + + def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long): Column = withAggregateFunction { + new ThetaSketchAggBuild(column.expr, lgk, seed) + } + + def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long): Column = { + theta_sketch_agg_build(Column(columnName), lgk, seed) + } + def theta_sketch_agg_build(column: Column, lgk: Int): Column = withAggregateFunction { new ThetaSketchAggBuild(column.expr, lgk) } diff --git a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala index 53bd0aa..62a547f 100644 --- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala +++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.datasketches.common.SparkSessionManager import org.apache.spark.sql.datasketches.theta.functions._ import org.apache.spark.sql.datasketches.theta.ThetaFunctionRegistry +import org.scalatest.matchers.should.Matchers._ + class ThetaTest extends SparkSessionManager { import spark.implicits._ - test("Theta Sketch build via Scala") { + test("Theta Sketch build via Scala with defaults") { val n = 100 val data = (for (i <- 1 to n) yield i).toDF("value") @@ -35,7 +37,37 @@ class ThetaTest extends SparkSessionManager { assert(result.getAs[Double]("estimate") == 100.0) } - test("Theta Sketch build via SQL default lgk") { + test("Theta Sketch build via Scala with lgk") { + val n = 100 + val data = (for (i <- 1 to n) yield i).toDF("value") + + val sketchDf = data.agg(theta_sketch_agg_build("value", 14).as("sketch")) + val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head() + + assert(result.getAs[Double]("estimate") == 100.0) + } + + test("Theta Sketch build via Scala with lgk and seed") { + val n = 100 + val data = (for (i <- 1 to n) yield i).toDF("value") + + val sketchDf = data.agg(theta_sketch_agg_build("value", 14, 111).as("sketch")) + val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head() + + assert(result.getAs[Double]("estimate") == 100.0) + } + + test("Theta Sketch build via Scala with lgk, seed and p") { + val n = 100 + val data = (for (i <- 1 to n) yield i).toDF("value") + + val sketchDf = data.agg(theta_sketch_agg_build("value", 14, 111, 0.99f).as("sketch")) + val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head() + + result.getAs[Double]("estimate") shouldBe (100.0 +- 2.0) + } + + test("Theta Sketch build via SQL with defaults") { ThetaFunctionRegistry.registerFunctions(spark) val n = 100 @@ -67,6 +99,38 @@ class ThetaTest extends SparkSessionManager { assert(df.head().getAs[Double]("estimate") == 100.0) } + test("Theta Sketch build via SQL with lgk and seed") { + ThetaFunctionRegistry.registerFunctions(spark) + + val n = 100 + val data = (for (i <- 1 to n) yield i).toDF("value") + data.createOrReplaceTempView("theta_input_table") + + val df = spark.sql(s""" + SELECT + theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L)) AS estimate + FROM + theta_input_table + """) + assert(df.head().getAs[Double]("estimate") == 100.0) + } + + test("Theta Sketch build via SQL with lgk, seed and p") { + ThetaFunctionRegistry.registerFunctions(spark) + + val n = 100 + val data = (for (i <- 1 to n) yield i).toDF("value") + data.createOrReplaceTempView("theta_input_table") + + val df = spark.sql(s""" + SELECT + theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L, 0.99f)) AS estimate + FROM + theta_input_table + """) + df.head().getAs[Double]("estimate") shouldBe (100.0 +- 2.0) + } + test("Theta Union via Scala") { val numGroups = 10 val numDistinct = 2000 --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
