This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch nullable_theta
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit 69361b01f6787f7a7117cfe94b45622c184ba825
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Tue Apr 1 23:09:09 2025 -0700

    support nullable
---
 .../common/DatasketchesScalaFunctionsBase.scala    |  2 +-
 .../theta/aggregate/ThetaSketchAggBuild.scala      | 37 ++++++++++------------
 .../theta/expressions/ThetaExpressions.scala       |  8 +++--
 .../spark/sql/datasketches/theta/functions.scala   |  8 +++++
 .../theta/types/ThetaSketchWrapper.scala           |  5 +--
 .../spark/sql/datasketches/theta/ThetaTest.scala   | 12 +++++++
 6 files changed, 47 insertions(+), 25 deletions(-)

diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
index b060dad..67b44a2 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesScalaFunctionsBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.Column
 
-// this interfact provides a few helper methods defines and maps all the 
variants of each function invocation, analagous
+// this interface provides a few helper methods, and defines and maps all the 
variants of each function invocation, analogous
 // to the functions object in core Spark's org.apache.spark.sql.functions
 trait DatasketchesScalaFunctionBase {
   protected def withExpr(expr: => Expression): Column = Column(expr)
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
index f6eaba5..c8a43a9 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
@@ -58,9 +58,10 @@ import 
org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
 // scalastyle:on line.size.limit
 case class ThetaSketchAggBuild(
     inputExpr: Expression,
-    lgKExpr: Expression,
-    seedExpr: Expression,
-    pExpr: Expression,
+    lgKExpr: Expression = Literal(DEFAULT_LG_K),
+    seedExpr: Expression = Literal(DEFAULT_UPDATE_SEED),
+    pExpr: Expression = Literal(1f),
+    nullable: Boolean = true,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[ThetaSketchWrapper]
@@ -98,7 +99,7 @@ case class ThetaSketchAggBuild(
   override def third: Expression = seedExpr
   override def fourth: Expression = pExpr
 
-  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, 
pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0)
+  def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, 
pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, true, 0, 0)
   def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = 
this(inputExpr, lgKExpr, seedExpr, Literal(1f))
   def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, 
lgKExpr, Literal(DEFAULT_UPDATE_SEED))
   def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
@@ -106,6 +107,7 @@ case class ThetaSketchAggBuild(
   def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
   def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, 
Literal(lgK), Literal(seed))
   def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float) = 
this(inputExpr, Literal(lgK), Literal(seed), Literal(p))
+  def this(inputExpr: Expression, lgK: Int, seed: Long, p: Float, nullable: 
Boolean) = this(inputExpr, Literal(lgK), Literal(seed), Literal(p), nullable)
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ThetaSketchAggBuild =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -121,16 +123,16 @@ case class ThetaSketchAggBuild(
 
   override def dataType: DataType = ThetaSketchType
 
-  override def nullable: Boolean = false
-
   override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, StringType), IntegerType, LongType, FloatType)
 
-  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper(updateSketch
-    = 
Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build()))
+  override def createAggregationBuffer(): ThetaSketchWrapper = new 
ThetaSketchWrapper()
 
   override def update(wrapper: ThetaSketchWrapper, input: InternalRow): 
ThetaSketchWrapper = {
     val value = inputExpr.eval(input)
     if (value != null) {
+      if (wrapper.updateSketch.isEmpty) {
+        wrapper.updateSketch = 
Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build())
+      }
       inputExpr.dataType match {
         case DoubleType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Double])
         case FloatType => 
wrapper.updateSketch.get.update(value.asInstanceOf[Float])
@@ -145,13 +147,9 @@ case class ThetaSketchAggBuild(
   }
 
   override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): 
ThetaSketchWrapper = {
-    if (other != null && !other.compactSketch.get.isEmpty()) {
+    if (other.compactSketch.isDefined) {
       if (wrapper.union.isEmpty) {
         wrapper.union = 
Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).buildUnion())
-        if (wrapper.compactSketch.isDefined) {
-          wrapper.union.get.union(wrapper.compactSketch.get)
-          wrapper.compactSketch = None
-        }
       }
       wrapper.union.get.union(other.compactSketch.get)
     }
@@ -159,18 +157,17 @@ case class ThetaSketchAggBuild(
   }
 
   override def eval(wrapper: ThetaSketchWrapper): Any = {
-    if (wrapper == null || wrapper.union.isEmpty) {
-      null
-    } else {
-      wrapper.union.get.getResult.toByteArrayCompressed()
-    }
+    val result = wrapper.serialize()
+    if (result != null) return result
+    if (nullable) return null
+    
UpdateSketch.builder().setSeed(seed).build().compact().toByteArrayCompressed()
   }
 
   override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
-    ThetaSketchType.serialize(wrapper)
+    wrapper.serialize
   }
 
   override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
-    ThetaSketchType.deserialize(bytes)
+    ThetaSketchWrapper.deserialize(bytes)
   }
 }
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
index dc9985c..dc68fde 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
@@ -63,8 +63,12 @@ case class ThetaSketchGetEstimate(child: Expression)
     val sketch = ctx.freshName("sketch")
     val code = s"""
       ${childEval.code}
-      final org.apache.datasketches.theta.Sketch $sketch = 
org.apache.spark.sql.datasketches.theta.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
-      final double ${ev.value} = $sketch.getEstimate();
+      if (${childEval.isNull}) {
+        ${ev.value} = null;
+      } else {
+        final org.apache.datasketches.theta.Sketch $sketch = 
org.apache.spark.sql.datasketches.theta.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
+        ${ev.value} = $sketch.getEstimate();
+      }
     """
     ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = childEval.isNull)
   }
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
index 7445d50..8de029f 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
@@ -26,6 +26,14 @@ import 
org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstima
 import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
 
 object functions extends DatasketchesScalaFunctionBase {
+  def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float, 
nullable: Boolean): Column = withAggregateFunction {
+    new ThetaSketchAggBuild(column.expr, lgk, seed, p, nullable)
+  }
+
+  def theta_sketch_agg_build(columnName: String, lgk: Int, seed: Long, p: 
Float, nullable: Boolean): Column = {
+    theta_sketch_agg_build(Column(columnName), lgk, seed, p, nullable)
+  }
+
   def theta_sketch_agg_build(column: Column, lgk: Int, seed: Long, p: Float): 
Column = withAggregateFunction {
     new ThetaSketchAggBuild(column.expr, lgk, seed, p)
   }
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
index a699017..c0967d4 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/theta/types/ThetaSketchWrapper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types.SQLUserDefinedType
 @SQLUserDefinedType(udt = classOf[ThetaSketchType])
 class ThetaSketchWrapper(var updateSketch: Option[UpdateSketch] = None, var 
compactSketch: Option[CompactSketch] = None, var union: Option[Union] = None) {
 
-  def serialize: Array[Byte] = {
+  def serialize(): Array[Byte] = {
     if (updateSketch.isDefined) return 
updateSketch.get.compact().toByteArrayCompressed
     else if (compactSketch.isDefined) return 
compactSketch.get.toByteArrayCompressed
     else if (union.isDefined) return union.get.getResult.toByteArrayCompressed
@@ -35,12 +35,13 @@ class ThetaSketchWrapper(var updateSketch: 
Option[UpdateSketch] = None, var comp
     if (updateSketch.isDefined) return updateSketch.get.toString
     else if (compactSketch.isDefined) return compactSketch.get.toString
     else if (union.isDefined) return union.get.toString
-    ""
+    "NULL"
   }
 }
 
 object ThetaSketchWrapper {
   def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    if (bytes == null) return new ThetaSketchWrapper()
     new ThetaSketchWrapper(compactSketch = 
Some(CompactSketch.heapify(Memory.wrap(bytes))))
   }
 
diff --git 
a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala 
b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
index 09add01..8b1dd99 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
@@ -23,10 +23,22 @@ import org.apache.spark.sql.datasketches.theta.functions._
 import org.apache.spark.sql.datasketches.theta.ThetaFunctionRegistry
 
 import org.scalatest.matchers.should.Matchers._
+import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
 
 class ThetaTest extends SparkSessionManager {
   import spark.implicits._
 
+  test("Theta Sketch build via Scala with defaults null input") {
+    val seq: Seq[Integer] = Seq(null, null, null, null, null, null, null, 
null, null, null)
+    val df = seq.toDF("value")
+
+    val sketchDf = df.agg(theta_sketch_agg_build("value").as("sketch"))
+    assert(sketchDf.head().getAs[ThetaSketchType]("sketch") == null)
+
+    val result: Row = 
sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head()
+    assert(result.getAs[Double]("estimate") == 0)
+  }
+
   test("Theta Sketch build via Scala with defaults") {
     val n = 100
     val data = (for (i <- 1 to n) yield i).toDF("value")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to