This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch nullable_theta
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit 69361b01f6787f7a7117cfe94b45622c184ba825
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Tue Apr 1 23:09:09 2025 -0700

    support nullable
---
 .../common/DatasketchesScalaFunctionsBase.scala    |  2 +-
 .../theta/aggregate/ThetaSketchAggBuild.scala      | 37 ++++++++++------------
 .../theta/expressions/ThetaExpressions.scala       |  8 +++--
 .../spark/sql/datasketches/theta/functions.scala   |  8 +++++
 .../theta/types/ThetaSketchWrapper.scala           |  5 +--
 .../spark/sql/datasketches/theta/ThetaTest.scala   | 12 +++++++
 6 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
index b060dad..67b44a2 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.Column
 
-// this interfact provides a few helper methods defines and maps all the variants of each function invocation, analagous
+// this interface provides a few helper methods defines and maps all the variants of each function invocation, analagous
 // to the functions object in core Spark's org.apache.spark.sql.functions
 trait DatasketchesScalaFunctionBase {
   protected def withExpr(expr: => Expression): Column = Column(expr)
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
index f6eaba5..c8a43a9 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
@@ -58,9 +58,10 @@ import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
 // scalastyle:on line.size.limit
 case class ThetaSketchAggBuild(
     inputExpr: Expression,
-    lgKExpr: Expression,
-    seedExpr: Expression,
-    pExpr: Expression,
+    lgKExpr: Expression = Literal(DEFAULT_LG_K),
+    seedExpr: Expression = Literal(DEFAULT_UPDATE_SEED),
+    pExpr: Expression = Literal(1f),
+    nullable: Boolean = true,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[ThetaSketchWrapper]
@@ -98,7 +99,7 @@ case class ThetaSketchAggBuild(
   override def third: Expression = seedExpr
   override def fourth: Expression = pExpr
 
-  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0)
+  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, true, 0, 0)
   def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, Literal(1f))
   def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, lgKExpr, Literal(DEFAULT_UPDATE_SEED))
   def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
@@ -106,6 +107,7 @@ case class ThetaSketchAggBuild(
   def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
   def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, Literal(lgK), Literal(seed))
   def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float) = this(inputExpr, Literal(lgK), Literal(seed), Literal(p))
+  def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float, nullable: Boolean) = this(inputExpr, Literal(lgK), Literal(seed), Literal(p), nullable)
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAggBuild =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -121,16 +123,16 @@ case class ThetaSketchAggBuild(
 
   override def dataType: DataType = ThetaSketchType
 
-  override def nullable: Boolean = false
-
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, StringType), IntegerType, LongType, FloatType)
 
-  override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(updateSketch
-    = Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build()))
+  override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper()
 
   override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
     val value = inputExpr.eval(input)
     if (value != null) {
+      if (wrapper.updateSketch.isEmpty) {
+        wrapper.updateSketch = Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build())
+      }
       inputExpr.dataType match {
         case DoubleType => wrapper.updateSketch.get.update(value.asInstanceOf[Double])
         case FloatType => wrapper.updateSketch.get.update(value.asInstanceOf[Float])
@@ -145,13 +147,9 @@ case class ThetaSketchAggBuild(
   }
 
   override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
-    if (other != null && !other.compactSketch.get.isEmpty()) {
+    if (other.compactSketch.isDefined) {
      if (wrapper.union.isEmpty) {
        wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).buildUnion())
-        if (wrapper.compactSketch.isDefined) {
-          wrapper.union.get.union(wrapper.compactSketch.get)
-          wrapper.compactSketch = None
-        }
       }
       wrapper.union.get.union(other.compactSketch.get)
     }
@@ -159,18 +157,17 @@ case class ThetaSketchAggBuild(
   }
 
   override def eval(wrapper: ThetaSketchWrapper): Any = {
-    if (wrapper == null || wrapper.union.isEmpty) {
-      null
-    } else {
-      wrapper.union.get.getResult.toByteArrayCompressed()
-    }
+    val result = wrapper.serialize()
+    if (result != null) return result
+    if (nullable) return null
+    UpdateSketch.builder().setSeed(seed).build().compact().toByteArrayCompressed()
   }
 
   override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
-    ThetaSketchType.serialize(wrapper)
+    wrapper.serialize
   }
 
   override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
-    ThetaSketchType.deserialize(bytes)
+    ThetaSketchWrapper.deserialize(bytes)
   }
 }
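For reviewers: the new eval() above returns the serialized sketch when any input was seen, SQL NULL when nullable is true, and otherwise falls back to an empty compact sketch. A minimal standalone sketch of that fallback, using only calls that appear in this patch (datasketches-java on the classpath is assumed):

    import org.apache.datasketches.theta.UpdateSketch

    // An update sketch that never received input, compacted: these are the
    // bytes eval() falls back to when nullable = false and no values arrived,
    // so downstream consumers see a valid empty sketch (estimate 0.0), not NULL.
    val emptyBytes: Array[Byte] =
      UpdateSketch.builder().build().compact().toByteArrayCompressed()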
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
index dc9985c..dc68fde 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
@@ -63,8 +63,12 @@ case class ThetaSketchGetEstimate(child: Expression)
     val sketch = ctx.freshName("sketch")
     val code = s"""
       ${childEval.code}
-      final org.apache.datasketches.theta.Sketch $sketch = org.apache.spark.sql.datasketches.theta.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
-      final double ${ev.value} = $sketch.getEstimate();
+      if (${childEval.isNull}) {
+        ${ev.value} = null;
+      } else {
+        final org.apache.datasketches.theta.Sketch $sketch = org.apache.spark.sql.datasketches.theta.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
+        ${ev.value} = $sketch.getEstimate();
+      }
     """
     ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = childEval.isNull)
   }
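With the codegen change above, a NULL sketch value never reaches wrapAsReadOnlySketch; the estimate is simply NULL. A hedged usage sketch (a SparkSession in scope is assumed, as in the tests further down):

    import org.apache.spark.sql.datasketches.theta.functions._
    import spark.implicits._

    // With the default nullable = true, an all-null column aggregates to a
    // NULL sketch, and asking for its estimate now yields NULL, not a failure.
    val sketchDf = Seq[Integer](null, null).toDF("value")
      .agg(theta_sketch_agg_build("value").as("sketch"))
    val row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
    // Note: Row.getAs[Double] unboxes a SQL NULL to 0.0, which is why the new
    // test below can assert estimate == 0.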
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
index 7445d50..8de029f 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
@@ -26,6 +26,14 @@ import org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstima
 import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
 
 object functions extends DatasketchesScalaFunctionBase {
+  def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float, nullable: Boolean): Column = withAggregateFunction {
+    new ThetaSketchAggBuild(column.expr, lgk, seed, p, nullable)
+  }
+
+  def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long, p: Float, nullable: Boolean): Column = {
+    theta_sketch_agg_build(Column(columnName), lgk, seed, p, nullable)
+  }
+
   def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float): Column = withAggregateFunction {
     new ThetaSketchAggBuild(column.expr, lgk, seed, p)
   }
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
index a699017..c0967d4 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types.SQLUserDefinedType
 @SQLUserDefinedType(udt = classOf[ThetaSketchType])
 class ThetaSketchWrapper(var updateSketch: Option[UpdateSketch] = None, var compactSketch: Option[CompactSketch] = None, var union: Option[Union] = None) {
 
-  def serialize: Array[Byte] = {
+  def serialize(): Array[Byte] = {
     if (updateSketch.isDefined) return updateSketch.get.compact().toByteArrayCompressed
     else if (compactSketch.isDefined) return compactSketch.get.toByteArrayCompressed
     else if (union.isDefined) return union.get.getResult.toByteArrayCompressed
@@ -35,12 +35,13 @@ class ThetaSketchWrapper(var updateSketch: Option[UpdateSketch] = None, var comp
     if (updateSketch.isDefined) return updateSketch.get.toString
     else if (compactSketch.isDefined) return compactSketch.get.toString
     else if (union.isDefined) return union.get.toString
-    ""
+    "NULL"
   }
 }
 
 object ThetaSketchWrapper {
   def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    if (bytes == null) return new ThetaSketchWrapper()
     new ThetaSketchWrapper(compactSketch = Some(CompactSketch.heapify(Memory.wrap(bytes))))
   }
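A usage sketch of the new five-argument overload added above; the literals 12, 9001L, and 1.0f are assumed stand-ins for DEFAULT_LG_K, DEFAULT_UPDATE_SEED, and the default sampling probability:

    import org.apache.spark.sql.datasketches.theta.functions._
    import spark.implicits._

    val nulls = Seq[Integer](null, null, null).toDF("value")

    // nullable = false: the aggregate returns an empty sketch instead of SQL
    // NULL, so the downstream estimate is a real 0.0 rather than NULL.
    nulls.agg(theta_sketch_agg_build("value", 12, 9001L, 1.0f, nullable = false).as("sketch"))
      .select(theta_sketch_get_estimate("sketch").as("estimate"))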
diff --git a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
index 09add01..8b1dd99 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
@@ -23,10 +23,22 @@ import org.apache.spark.sql.datasketches.theta.functions._
 import org.apache.spark.sql.datasketches.theta.ThetaFunctionRegistry
 
 import org.scalatest.matchers.should.Matchers._
+import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
 
 class ThetaTest extends SparkSessionManager {
   import spark.implicits._
 
+  test("Theta Sketch build via Scala with defaults null input") {
+    val seq: Seq[Integer] = Seq(null, null, null, null, null, null, null, null, null, null)
+    val df = seq.toDF("value")
+
+    val sketchDf = df.agg(theta_sketch_agg_build("value").as("sketch"))
+    assert(sketchDf.head().getAs[ThetaSketchType]("sketch") == null)
+
+    val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
+    assert(result.getAs[Double]("estimate") == 0)
+  }
+
   test("Theta Sketch build via Scala with defaults") {
     val n = 100
     val data = (for (i <- 1 to n) yield i).toDF("value")

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
