This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch kll_merge
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit ea0c699d1c482b7dead4fb1f42052c72d9601d78
Author: Jon Malkin <[email protected]>
AuthorDate: Tue Jan 7 20:00:41 2025 -0800

    Update kll merge to accept k or fall back to a default value. SQL currently complains if specifying k
---
 .../apache/spark/sql/aggregate/KllAggregate.scala  |   8 +-
 .../org/apache/spark/sql/aggregate/KllMerge.scala  | 135 +++++++++++----------
 .../scala/org/apache/spark/sql/functions_ds.scala  |  12 ++
 src/test/scala/org/apache/spark/sql/KllTest.scala  |  57 +++++++--
 .../org/apache/spark/sql/SparkSessionManager.scala |   1 +
 5 files changed, 141 insertions(+), 72 deletions(-)
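
For orientation, a minimal sketch of how the new overloads could be exercised from Scala, assuming the functions_ds API added in this commit; the session setup, DataFrame contents, and column names below are illustrative only:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions_ds._

// illustrative local session and toy data; any numeric column works
val spark = SparkSession.builder()
  .appName("kll-merge-example")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

val data = Seq((1, 1.0), (1, 2.0), (2, 2.0), (2, 3.0)).toDF("id", "value")

// build one KLL sketch per id
val sketches = data.groupBy($"id").agg(kll_sketch_agg($"value").as("sketch"))

// merge with the default k (KllSketch.DEFAULT_K)
val mergedDefault = sketches.agg(kll_merge_agg($"sketch").as("merged"))

// merge with an explicit k = 160, the overload added in this commit;
// per the commit message, the SQL path still complains when k is specified,
// so explicit k is exercised here through the DataFrame API only
val mergedExplicit = sketches.agg(kll_merge_agg($"sketch", 160).as("merged"))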

diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
index c77c7ad..ae2422a 100644
--- a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
+++ b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
@@ -100,6 +100,8 @@ case class KllDoublesSketchAgg(
 
   override def nullable: Boolean = false
 
+  override def stateful: Boolean = true
+
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
@@ -107,12 +109,12 @@ case class KllDoublesSketchAgg(
     if (!right.foldable) {
       return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: ${right}")
     }
-    // Check if k > 0
+    // Check if k >= 8 and k <= MAX_K
     right.eval() match {
-      case k: Int if k > 0 => // valid state, do nothing
+      case k: Int if k >= 8 && k <= KllSketch.MAX_K => // valid state, do nothing
       case k: Int if k > KllSketch.MAX_K => return TypeCheckResult.TypeCheckFailure(
         s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k")
-      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be greater than 0, but got: $k")
+      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be at least 8 and no greater than ${KllSketch.MAX_K}, but got: $k")
       case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input type ${right.dataType.catalogString}")
     }
 
diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala b/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
index 4a0d572..77ef12a 100644
--- a/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
+++ b/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
@@ -17,45 +17,59 @@
 
 package org.apache.spark.sql.aggregate
 
+import org.apache.datasketches.memory.Memory
 import org.apache.datasketches.kll.{KllSketch, KllDoublesSketch}
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, KllDoublesSketchType}
-import org.apache.datasketches.memory.Memory
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, KllDoublesSketchType}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 
 /**
  * The KllDoublesMergeAgg function utilizes a Datasketches KllDoublesSketch instance to
  * combine multiple sketches into a single sketch.
  *
- * @param child child expression against which the sketch will be created
+ * @param left expression providing the sketches to merge
+ * @param right k, the size-accuracy trade-off parameter for the sketch, an int in the range [8, 65535]
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, k) - Merges multiple KllDoublesSketch images and returns the binary representation
+    _FUNC_(expr[, k]) - Merges multiple KllDoublesSketch images and returns the binary representation
     """,
   examples = """
     Examples:
-      > SELECT kll_get_quantile(_FUNC_(sketch), 0.5) FROM (SELECT kll_sketch_agg(col) as sketch FROM VALUES (1.0), (2.0) tab(col) UNION ALL SELECT kll_sketch_agg(col) as sketch FROM VALUES (2.0), (3.0) tab(col));
+      > SELECT kll_get_quantile(_FUNC_(sketch, 200), 0.5) FROM (SELECT kll_sketch_agg(col) as sketch FROM VALUES (1.0), (2.0) tab(col) UNION ALL SELECT kll_sketch_agg(col) as sketch FROM VALUES (2.0), (3.0) tab(col));
        2.0
   """,
   //group = "agg_funcs",
 )
 // scalastyle:on line.size.limit
 case class KllDoublesMergeAgg(
-    child: Expression,
+    left: Expression,
+    right: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Option[KllDoublesSketch]]
-    with UnaryLike[Expression]
+  extends TypedImperativeAggregate[KllDoublesSketch]
+    with BinaryLike[Expression]
     with ExpectsInputTypes {
 
+  lazy val k: Int = {
+    right.eval() match {
+      case null => KllSketch.DEFAULT_K
+      case k: Int => k
+      // this shouldn't happen after checkInputDataTypes()
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
   // Constructors
-  def this(child: Expression) = this(child, 0, 0)
+  def this(left: Expression) = this(left, Literal(KllSketch.DEFAULT_K), 0, 0)
 
   // Copy constructors
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): KllDoublesMergeAgg =
@@ -64,9 +78,8 @@ case class KllDoublesMergeAgg(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): KllDoublesMergeAgg =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override protected def withNewChildInternal(newChild: Expression): KllDoublesMergeAgg =
-    copy(child = newChild)
-
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): KllDoublesMergeAgg =
+    copy(left = newLeft, right = newRight)
 
   // overrides for TypedImperativeAggregate
   override def prettyName: String = "kll_merge_agg"
@@ -75,74 +88,74 @@ case class KllDoublesMergeAgg(
 
   override def nullable: Boolean = false
 
-  // TODO: refine this?
-  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType)
+  override def stateful: Boolean = true
 
-  // create buffer
-  override def createAggregationBuffer(): Option[KllDoublesSketch] = {
-    None
+  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // k must be a constant
+    if (!right.foldable) {
+      return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: ${right}")
+    }
+    // Check if k >= 8 and k <= MAX_K
+    right.eval() match {
+      case k: Int if k >= 8 && k <= KllSketch.MAX_K => // valid state, do nothing
+      case k: Int if k > KllSketch.MAX_K => return TypeCheckResult.TypeCheckFailure(
+        s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k")
+      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be at least 8 and no greater than ${KllSketch.MAX_K}, but got: $k")
+      case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input type ${right.dataType.catalogString}")
+    }
+
+    // additional validations of k handled in the DataSketches library
+    TypeCheckResult.TypeCheckSuccess
+  }
+
+  override def createAggregationBuffer(): KllDoublesSketch = {
+    KllDoublesSketch.newHeapInstance(k)
   }
 
-  // update
-  override def update(unionOption: Option[KllDoublesSketch], input: InternalRow): Option[KllDoublesSketch] = {
-    val value = child.eval(input)
+  override def update(union: KllDoublesSketch, input: InternalRow): KllDoublesSketch = {
+    val value = left.eval(input)
     if (value != null && value != None) {
-      child.dataType match {
+      left.dataType match {
         case KllDoublesSketchType =>
-          if (unionOption == None || unionOption.get.isEmpty) {
-            // if union is empty, just return a copy of the input sketch
-            // TODO: is this serialized or already as a sketch object?
-            Some(KllDoublesSketch.heapify(Memory.wrap(value.asInstanceOf[Array[Byte]])))
-          } else {
-            unionOption.get.merge(KllDoublesSketch.wrap(Memory.wrap(value.asInstanceOf[Array[Byte]])))
-            unionOption
-          }
+          union.merge(KllDoublesSketch.wrap(Memory.wrap(value.asInstanceOf[Array[Byte]])))
+          union
         case _ => throw new SparkUnsupportedOperationException(
-          s"Unsupported input type ${child.dataType.catalogString}",
+          s"Unsupported input type ${left.dataType.catalogString}",
           Map("dataType" -> dataType.toString))
       }
     } else {
-      unionOption
+      union
     }
   }
 
-  // union (merge)
-  override def merge(unionOption: Option[KllDoublesSketch], otherOption: Option[KllDoublesSketch]): Option[KllDoublesSketch] = {
-    (unionOption, otherOption) match {
-      case (Some(union), Some(other)) =>
-        union.merge(other)
-        Some(union)
-
-      // for these others, we'll return the input even if degenerate
-      case (Some(union), None) =>
-        unionOption
-      case (None, Some(other)) =>
-        otherOption
-      case (None, None) =>
-        unionOption
+  override def merge(union: KllDoublesSketch, other: KllDoublesSketch): KllDoublesSketch = {
+    if (union != null && other != null) {
+      union.merge(other)
+      union
+    } else if (union != null && other == null) {
+      union
+    } else if (union == null && other != null) {
+      other
+    } else {
+      union
     }
   }
 
-  // eval
-  override def eval(unionOption: Option[KllDoublesSketch]): Any = {
-    unionOption match {
-      case Some(sketch) => sketch.toByteArray
-      case None => None // can this happen in practice? If so, what should we return?
-    }
+  override def eval(sketch: KllDoublesSketch): Any = {
+    sketch.toByteArray
   }
 
-  override def serialize(sketchOption: Option[KllDoublesSketch]): Array[Byte] = {
-    sketchOption match {
-      case Some(sketch) => sketch.toByteArray
-      case None => KllDoublesSketch.newHeapInstance(KllSketch.DEFAULT_K).toByteArray
-    }
+  override def serialize(sketch: KllDoublesSketch): Array[Byte] = {
+    sketch.toByteArray()
   }
 
-  override def deserialize(bytes: Array[Byte]): Option[KllDoublesSketch] = {
+  override def deserialize(bytes: Array[Byte]): KllDoublesSketch = {
     if (bytes.length > 0) {
-      Some(KllDoublesSketchType.deserialize(bytes))
+      KllDoublesSketchType.deserialize(bytes)
     } else {
-      None
+      null
     }
   }
 }
diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala b/src/main/scala/org/apache/spark/sql/functions_ds.scala
index 7fa9e7f..62aaf95 100644
--- a/src/main/scala/org/apache/spark/sql/functions_ds.scala
+++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala
@@ -84,6 +84,18 @@ object functions_ds {
     kll_merge_agg(Column(columnName))
   }
 
+  def kll_merge_agg(expr: Column, k: Column): Column = withAggregateFunction {
+    new KllDoublesMergeAgg(expr.expr, k.expr)
+  }
+
+  def kll_merge_agg(expr: Column, k: Int): Column = withAggregateFunction {
+    new KllDoublesMergeAgg(expr.expr, lit(k).expr)
+  }
+
+  def kll_merge_agg(columnName: String, k: Int): Column = {
+    kll_merge_agg(Column(columnName), lit(k))
+  }
+
   // get PMF
   def kll_get_pmf(sketch: Column, splitPoints: Column, isInclusive: Boolean): Column = withExpr {
     new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal.create(isInclusive, BooleanType), true)
diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala b/src/test/scala/org/apache/spark/sql/KllTest.scala
index f9c835d..1824ef8 100644
--- a/src/test/scala/org/apache/spark/sql/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -180,16 +180,32 @@ class KllTest extends SparkSessionManager {
     // create a sketch for each id value
     val idSketchDf = data.groupBy($"id").agg(kll_sketch_agg($"value").as("sketch"))
 
+    // default k
     // merge into an aggregate sketch
-    val mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch").as("sketch"))
+    var mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch").as("sketch"))
 
     // check min and max
-    val result: Row = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
+    var result: Row = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
                                             kll_get_max($"sketch").as("max"))
                                     .head
 
-    val sketchMin = result.getAs[Double]("min")
-    val sketchMax = result.getAs[Double]("max")
+    var sketchMin = result.getAs[Double]("min")
+    var sketchMax = result.getAs[Double]("max")
+
+    assert(globalMin == sketchMin)
+    assert(globalMax == sketchMax)
+
+    // specified k
+    // merge into an aggregate sketch
+    mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch", 160).as("sketch"))
+
+    // check min and max
+    result = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
+                                   kll_get_max($"sketch").as("max"))
+                           .head
+
+    sketchMin = result.getAs[Double]("min")
+    sketchMax = result.getAs[Double]("max")
 
     assert(globalMin == sketchMin)
     assert(globalMax == sketchMax)
@@ -222,8 +238,9 @@ class KllTest extends SparkSessionManager {
     )
     idSketchDf.createOrReplaceTempView("sketch_table")
 
+    // default k
     // now merge the sketches
-    val mergedSketchDf = spark.sql(
+    var mergedSketchDf = spark.sql(
       s"""
       |SELECT
       |  kll_get_min(sub.sketch) AS min,
@@ -238,9 +255,33 @@ class KllTest extends SparkSessionManager {
     )
 
     // check min and max
-    val result: Row = mergedSketchDf.head
-    val sketchMin = result.getAs[Double]("min")
-    val sketchMax = result.getAs[Double]("max")
+    var result: Row = mergedSketchDf.head
+    var sketchMin = result.getAs[Double]("min")
+    var sketchMax = result.getAs[Double]("max")
+
+    assert(globalMin == sketchMin)
+    assert(globalMax == sketchMax)
+
+    // specified k
+    // now merge the sketches
+    mergedSketchDf = spark.sql(
+      s"""
+      |SELECT
+      |  kll_get_min(sub.sketch) AS min,
+      |  kll_get_max(sub.sketch) AS max
+      |FROM
+      |  (SELECT
+      |     kll_merge_agg(sketch, 160) AS sketch
+      |  FROM
+      |    sketch_table
+      |  ) sub
+      """.stripMargin
+    )
+
+    // check min and max
+    result = mergedSketchDf.head
+    sketchMin = result.getAs[Double]("min")
+    sketchMax = result.getAs[Double]("max")
 
     assert(globalMin == sketchMin)
     assert(globalMax == sketchMax)
diff --git a/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala b/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
index 3675620..4c96cb4 100644
--- a/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
+++ b/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
@@ -33,6 +33,7 @@ trait SparkSessionManager extends AnyFunSuite with BeforeAndAfterAll {
       .builder()
       .appName("datasketches-spark-tests")
       .master("local[3]")
+      //.config("spark.sql.debug.codegen", "true")
       .getOrCreate()
 
   override def beforeAll(): Unit = {

