This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch kll_merge
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit ea0c699d1c482b7dead4fb1f42052c72d9601d78
Author: Jon Malkin <[email protected]>
AuthorDate: Tue Jan 7 20:00:41 2025 -0800

    Update kll merge to accept k or fall back to a default value. SQL currently complains when specifying k
---
 .../apache/spark/sql/aggregate/KllAggregate.scala  |   8 +-
 .../org/apache/spark/sql/aggregate/KllMerge.scala  | 135 +++++++++++----------
 .../scala/org/apache/spark/sql/functions_ds.scala  |  12 ++
 src/test/scala/org/apache/spark/sql/KllTest.scala  |  57 +++++++--
 .../org/apache/spark/sql/SparkSessionManager.scala |   1 +
 5 files changed, 141 insertions(+), 72 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
index c77c7ad..ae2422a 100644
--- a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
+++ b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
@@ -100,6 +100,8 @@ case class KllDoublesSketchAgg(
 
   override def nullable: Boolean = false
 
+  override def stateful: Boolean = true
+
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
@@ -107,12 +109,12 @@ case class KllDoublesSketchAgg(
     if (!right.foldable) {
       return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: ${right}")
     }
-    // Check if k > 0
+    // Check if k >= 8 and k <= MAX_K
     right.eval() match {
-      case k: Int if k > 0 => // valid state, do nothing
+      case k: Int if k >= 8 && k <= KllSketch.MAX_K => // valid state, do nothing
       case k: Int if k > KllSketch.MAX_K => return TypeCheckResult.TypeCheckFailure(
         s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k")
-      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be greater than 0, but got: $k")
+      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be at least 8 and no greater than ${KllSketch.MAX_K}, but got: $k")
       case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input type ${right.dataType.catalogString}")
     }
diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala b/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
index 4a0d572..77ef12a 100644
--- a/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
+++ b/src/main/scala/org/apache/spark/sql/aggregate/KllMerge.scala
@@ -17,45 +17,59 @@ package org.apache.spark.sql.aggregate
 
+import org.apache.datasketches.memory.Memory
 import org.apache.datasketches.kll.{KllSketch, KllDoublesSketch}
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, KllDoublesSketchType}
-import org.apache.datasketches.memory.Memory
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, KllDoublesSketchType}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 
 /**
  * The KllDoublesMergeAgg function utilizes a Datasketches KllDoublesSketch instance to
  * combine multiple sketches into a single sketch.
  *
- * @param child child expression against which the sketch will be created
+ * @param left expression containing the sketches to be merged
+ * @param right k, the size-accuracy trade-off parameter for the sketch, int in range [8, 65535]
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, k) - Merges multiple KllDoublesSketch images and returns the binary representation
+    _FUNC_(expr[, k]) - Merges multiple KllDoublesSketch images and returns the binary representation
   """,
   examples = """
     Examples:
-      > SELECT kll_get_quantile(_FUNC_(sketch), 0.5) FROM (SELECT kll_sketch_agg(col) as sketch FROM VALUES (1.0), (2.0) tab(col) UNION ALL SELECT kll_sketch_agg(col) as sketch FROM VALUES (2.0), (3.0) tab(col));
+      > SELECT kll_get_quantile(_FUNC_(sketch, 200), 0.5) FROM (SELECT kll_sketch_agg(col) as sketch FROM VALUES (1.0), (2.0) tab(col) UNION ALL SELECT kll_sketch_agg(col) as sketch FROM VALUES (2.0), (3.0) tab(col));
       2.0
   """,
   //group = "agg_funcs",
 )
 // scalastyle:on line.size.limit
 case class KllDoublesMergeAgg(
-    child: Expression,
+    left: Expression,
+    right: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Option[KllDoublesSketch]]
-  with UnaryLike[Expression]
+  extends TypedImperativeAggregate[KllDoublesSketch]
+  with BinaryLike[Expression]
   with ExpectsInputTypes {
 
+  lazy val k: Int = {
+    right.eval() match {
+      case null => KllSketch.DEFAULT_K
+      case k: Int => k
+      // this shouldn't happen after checkInputDataTypes()
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
   // Constructors
-  def this(child: Expression) = this(child, 0, 0)
+  def this(left: Expression) = this(left, Literal(KllSketch.DEFAULT_K), 0, 0)
 
   // Copy constructors
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): KllDoublesMergeAgg =
@@ -64,9 +78,8 @@ case class KllDoublesMergeAgg(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): KllDoublesMergeAgg =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override protected def withNewChildInternal(newChild: Expression): KllDoublesMergeAgg =
-    copy(child = newChild)
-
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): KllDoublesMergeAgg =
+    copy(left = newLeft, right = newRight)
 
   // overrides for TypedImperativeAggregate
   override def prettyName: String = "kll_merge_agg"
@@ -75,74 +88,74 @@ case class KllDoublesMergeAgg(
 
   override def nullable: Boolean = false
 
-  // TODO: refine this?
-  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType)
+  override def stateful: Boolean = true
 
-  // create buffer
-  override def createAggregationBuffer(): Option[KllDoublesSketch] = {
-    None
+  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // k must be a constant
+    if (!right.foldable) {
+      return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: ${right}")
+    }
+    // Check if k >= 8 and k <= MAX_K
+    right.eval() match {
+      case k: Int if k >= 8 && k <= KllSketch.MAX_K => // valid state, do nothing
+      case k: Int if k > KllSketch.MAX_K => return TypeCheckResult.TypeCheckFailure(
+        s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k")
+      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be at least 8 and no greater than ${KllSketch.MAX_K}, but got: $k")
+      case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input type ${right.dataType.catalogString}")
+    }
+
+    // additional validations of k handled in the DataSketches library
+    TypeCheckResult.TypeCheckSuccess
+  }
+
+  override def createAggregationBuffer(): KllDoublesSketch = {
+    KllDoublesSketch.newHeapInstance(k)
   }
 
-  // update
-  override def update(unionOption: Option[KllDoublesSketch], input: InternalRow): Option[KllDoublesSketch] = {
-    val value = child.eval(input)
+  override def update(union: KllDoublesSketch, input: InternalRow): KllDoublesSketch = {
+    val value = left.eval(input)
     if (value != null && value != None) {
-      child.dataType match {
+      left.dataType match {
         case KllDoublesSketchType =>
-          if (unionOption == None || unionOption.get.isEmpty) {
-            // if union is empty, just return a copy of the input sketch
-            // TODO: is this serialized or already as a sketch object?
-            Some(KllDoublesSketch.heapify(Memory.wrap(value.asInstanceOf[Array[Byte]])))
-          } else {
-            unionOption.get.merge(KllDoublesSketch.wrap(Memory.wrap(value.asInstanceOf[Array[Byte]])))
-            unionOption
-          }
+          union.merge(KllDoublesSketch.wrap(Memory.wrap(value.asInstanceOf[Array[Byte]])))
+          union
         case _ => throw new SparkUnsupportedOperationException(
-          s"Unsupported input type ${child.dataType.catalogString}",
+          s"Unsupported input type ${left.dataType.catalogString}",
           Map("dataType" -> dataType.toString))
       }
     } else {
-      unionOption
+      union
    }
  }
 
-  // union (merge)
-  override def merge(unionOption: Option[KllDoublesSketch], otherOption: Option[KllDoublesSketch]): Option[KllDoublesSketch] = {
-    (unionOption, otherOption) match {
-      case (Some(union), Some(other)) =>
-        union.merge(other)
-        Some(union)
-
-      // for these others, we'll return the input even if degenerate
-      case (Some(union), None) =>
-        unionOption
-      case (None, Some(other)) =>
-        otherOption
-      case (None, None) =>
-        unionOption
+  override def merge(union: KllDoublesSketch, other: KllDoublesSketch): KllDoublesSketch = {
+    if (union != null && other != null) {
+      union.merge(other)
+      union
+    } else if (union != null && other == null) {
+      union
+    } else if (union == null && other != null) {
+      other
+    } else {
+      union
     }
   }
 
-  // eval
-  override def eval(unionOption: Option[KllDoublesSketch]): Any = {
-    unionOption match {
-      case Some(sketch) => sketch.toByteArray
-      case None => None // can this happen in practice? If so, what should we return?
-    }
+  override def eval(sketch: KllDoublesSketch): Any = {
+    sketch.toByteArray
   }
 
-  override def serialize(sketchOption: Option[KllDoublesSketch]): Array[Byte] = {
-    sketchOption match {
-      case Some(sketch) => sketch.toByteArray
-      case None => KllDoublesSketch.newHeapInstance(KllSketch.DEFAULT_K).toByteArray
-    }
+  override def serialize(sketch: KllDoublesSketch): Array[Byte] = {
+    sketch.toByteArray()
   }
 
-  override def deserialize(bytes: Array[Byte]): Option[KllDoublesSketch] = {
+  override def deserialize(bytes: Array[Byte]): KllDoublesSketch = {
     if (bytes.length > 0) {
-      Some(KllDoublesSketchType.deserialize(bytes))
+      KllDoublesSketchType.deserialize(bytes)
     } else {
-      None
+      null
     }
   }
 }
diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala b/src/main/scala/org/apache/spark/sql/functions_ds.scala
index 7fa9e7f..62aaf95 100644
--- a/src/main/scala/org/apache/spark/sql/functions_ds.scala
+++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala
@@ -84,6 +84,18 @@ object functions_ds {
     kll_merge_agg(Column(columnName))
   }
 
+  def kll_merge_agg(expr: Column, k: Column): Column = withAggregateFunction {
+    new KllDoublesMergeAgg(expr.expr, k.expr)
+  }
+
+  def kll_merge_agg(expr: Column, k: Int): Column = withAggregateFunction {
+    new KllDoublesMergeAgg(expr.expr, lit(k).expr)
+  }
+
+  def kll_merge_agg(columnName: String, k: Int): Column = {
+    kll_merge_agg(Column(columnName), lit(k))
+  }
+
   // get PMF
   def kll_get_pmf(sketch: Column, splitPoints: Column, isInclusive: Boolean): Column = withExpr {
     new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal.create(isInclusive, BooleanType), true)
diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala b/src/test/scala/org/apache/spark/sql/KllTest.scala
index f9c835d..1824ef8 100644
--- a/src/test/scala/org/apache/spark/sql/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -180,16 +180,32 @@ class KllTest extends SparkSessionManager {
     // create a sketch for each id value
     val idSketchDf = data.groupBy($"id").agg(kll_sketch_agg($"value").as("sketch"))
 
+    // default k
     // merge into an aggregate sketch
-    val mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch").as("sketch"))
+    var mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch").as("sketch"))
 
     // check min and max
-    val result: Row = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
+    var result: Row = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
       kll_get_max($"sketch").as("max"))
       .head
 
-    val sketchMin = result.getAs[Double]("min")
-    val sketchMax = result.getAs[Double]("max")
+    var sketchMin = result.getAs[Double]("min")
+    var sketchMax = result.getAs[Double]("max")
+
+    assert(globalMin == sketchMin)
+    assert(globalMax == sketchMax)
+
+    // specified k
+    // merge into an aggregate sketch
+    mergedSketchDf = idSketchDf.agg(kll_merge_agg($"sketch", 160).as("sketch"))
+
+    // check min and max
+    result = mergedSketchDf.select(kll_get_min($"sketch").as("min"),
+      kll_get_max($"sketch").as("max"))
+      .head
+
+    sketchMin = result.getAs[Double]("min")
+    sketchMax = result.getAs[Double]("max")
 
     assert(globalMin == sketchMin)
     assert(globalMax == sketchMax)
@@ -222,8 +238,9 @@ class KllTest extends SparkSessionManager {
     )
     idSketchDf.createOrReplaceTempView("sketch_table")
 
+    // default k
     // now merge the sketches
-    val mergedSketchDf = spark.sql(
+    var mergedSketchDf = spark.sql(
       s"""
         |SELECT
         |  kll_get_min(sub.sketch) AS min,
         |  kll_get_max(sub.sketch) AS max
         |FROM
         |  (SELECT
         |    kll_merge_agg(sketch) AS sketch
         |  FROM
         |    sketch_table
         |  ) sub
       """.stripMargin
     )
 
     // check min and max
-    val result: Row = mergedSketchDf.head
-    val sketchMin = result.getAs[Double]("min")
-    val sketchMax = result.getAs[Double]("max")
+    var result: Row = mergedSketchDf.head
+    var sketchMin = result.getAs[Double]("min")
+    var sketchMax = result.getAs[Double]("max")
+
+    assert(globalMin == sketchMin)
+    assert(globalMax == sketchMax)
+
+    // specified k
+    // now merge the sketches
+    mergedSketchDf = spark.sql(
+      s"""
+        |SELECT
+        |  kll_get_min(sub.sketch) AS min,
+        |  kll_get_max(sub.sketch) AS max
+        |FROM
+        |  (SELECT
+        |    kll_merge_agg(sketch, 160) AS sketch
+        |  FROM
+        |    sketch_table
+        |  ) sub
+      """.stripMargin
+    )
+
+    // check min and max
+    result = mergedSketchDf.head
+    sketchMin = result.getAs[Double]("min")
+    sketchMax = result.getAs[Double]("max")
 
     assert(globalMin == sketchMin)
     assert(globalMax == sketchMax)
diff --git a/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala b/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
index 3675620..4c96cb4 100644
--- a/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
+++ b/src/test/scala/org/apache/spark/sql/SparkSessionManager.scala
@@ -33,6 +33,7 @@ trait SparkSessionManager extends AnyFunSuite with BeforeAndAfterAll {
     .builder()
     .appName("datasketches-spark-tests")
     .master("local[3]")
+    //.config("spark.sql.debug.codegen", "true")
     .getOrCreate()
 
   override def beforeAll(): Unit = {
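For readers following along, a minimal, self-contained sketch (not part of the commit) of the DataFrame-side API this change introduces. It assumes a local SparkSession and the functions_ds helpers from this repository; the data, app name, and object name are illustrative only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions_ds._

object KllMergeExample {
  def main(args: Array[String]): Unit = {
    // Assumed setup: a local session purely for demonstration.
    val spark = SparkSession.builder()
      .appName("kll-merge-example")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    val data = Seq((1, 1.0), (1, 2.0), (2, 2.0), (2, 3.0)).toDF("id", "value")

    // Build one KLL sketch per id, as in the updated test.
    val perId = data.groupBy($"id").agg(kll_sketch_agg($"value").as("sketch"))

    // New in this commit: an explicit k (validated to be in [8, MAX_K]);
    // the single-argument overload falls back to KllSketch.DEFAULT_K.
    val merged = perId.agg(kll_merge_agg($"sketch", 160).as("sketch"))

    merged.select(kll_get_min($"sketch").as("min"),
      kll_get_max($"sketch").as("max")).show()

    spark.stop()
  }
}

Since k is the size-accuracy trade-off parameter of the merged sketch, the value 160 here is purely illustrative.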
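The same merge via SQL, lifted from the updated KllTest case. It assumes `sketch_table` is a temp view of per-group sketches, as created in the test; note the commit message above cautions that SQL currently complains when k is specified, so this path may fail until that is resolved.

// SQL form from the updated test; `sketch_table` holds one sketch per group.
// Per the commit message, specifying k via SQL may currently raise an error.
val mergedViaSql = spark.sql(
  s"""
    |SELECT
    |  kll_get_min(sub.sketch) AS min,
    |  kll_get_max(sub.sketch) AS max
    |FROM
    |  (SELECT
    |    kll_merge_agg(sketch, 160) AS sketch
    |  FROM
    |    sketch_table
    |  ) sub
  """.stripMargin
)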
