This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch simplify_registration in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit 455dcde210799a09b8f5df8dc5678d7f3a82376d Author: Jon Malkin <[email protected]> AuthorDate: Tue Jan 7 13:19:50 2025 -0800 Two unrelated changes. Add type-checking to k in sketch creation, and modify pmf/cdf to allow simpler registration for SQL --- build.sbt | 4 - .../apache/spark/sql/aggregate/KllAggregate.scala | 33 ++++-- .../spark/sql/expressions/KllExpressions.scala | 121 ++++++++++++++++----- .../scala/org/apache/spark/sql/functions_ds.scala | 11 +- .../registrar/DatasketchesFunctionRegistry.scala | 42 +++---- src/test/scala/org/apache/spark/sql/KllTest.scala | 4 +- 6 files changed, 141 insertions(+), 74 deletions(-) diff --git a/build.sbt b/build.sbt index daa4c63..eb93702 100644 --- a/build.sbt +++ b/build.sbt @@ -49,9 +49,5 @@ scalacOptions ++= Seq( Test / logBuffered := false -// Only show warnings and errors on the screen for compilations. -// This applies to both test:compile and compile and is Info by default -Compile / logLevel := Level.Warn - // Level.INFO is needed to see detailed output when running tests Test / logLevel := Level.Info diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala index 3b7a506..c77c7ad 100644 --- a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala +++ b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala @@ -24,14 +24,15 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType, KllDoublesSketchType} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult /** * The KllDoublesSketchAgg function utilizes a Datasketches KllDoublesSketch instance * to create a sketch from a column of values which can be used to estimate quantiles * and histograms. * - * @param child child expression against which the sketch will be created - * @param k the size-accuracy trade-off parameter for the sketch + * @param left Expression against which the sketch will be created + * @param right k, the size-accuracy trade-off parameter for the sketch, int in range [1, 65535] */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -58,6 +59,7 @@ case class KllDoublesSketchAgg( right.eval() match { case null => KllSketch.DEFAULT_K case k: Int => k + // this shouldn't happen after checkInputDataTypes() case _ => throw new SparkUnsupportedOperationException( s"Unsupported input type ${right.dataType.catalogString}", Map("dataType" -> dataType.toString)) @@ -65,7 +67,6 @@ case class KllDoublesSketchAgg( } // Constructors - def this(child: Expression) = { this(child, Literal(KllSketch.DEFAULT_K), 0, 0) } @@ -91,18 +92,36 @@ case class KllDoublesSketchAgg( } // overrides for TypedImperativeAggregate + override lazy val deterministic: Boolean = false + override def prettyName: String = "kll_sketch_agg" override def dataType: DataType = KllDoublesSketchType override def nullable: Boolean = false - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType, LongType, FloatType, DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + // k must be a constant + if (!right.foldable) { + return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: ${right}") + } + // Check if k > 0 + right.eval() match { + case k: Int if k > 0 => // valid state, do nothing + case k: Int if k > KllSketch.MAX_K => return TypeCheckResult.TypeCheckFailure( + s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k") + case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be greater than 0, but got: $k") + case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input type ${right.dataType.catalogString}") + } + + // additional validations of k handled in the DataSketches library + TypeCheckResult.TypeCheckSuccess + } - // create buffer override def createAggregationBuffer(): KllDoublesSketch = KllDoublesSketch.newHeapInstance(k) - // update override def update(sketch: KllDoublesSketch, input: InternalRow): KllDoublesSketch = { val value = left.eval(input) if (value != null) { @@ -119,7 +138,6 @@ case class KllDoublesSketchAgg( sketch } - // union (merge) override def merge(sketch: KllDoublesSketch, other: KllDoublesSketch): KllDoublesSketch = { if (other != null && !other.isEmpty) { sketch.merge(other) @@ -127,7 +145,6 @@ case class KllDoublesSketchAgg( sketch } - // eval override def eval(sketch: KllDoublesSketch): Any = { if (sketch == null || sketch.isEmpty) { null diff --git a/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala b/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala index 745261b..34b6296 100644 --- a/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala +++ b/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.expressions import org.apache.datasketches.memory.Memory import org.apache.datasketches.kll.KllDoublesSketch -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpectsInputTypes, UnaryExpression, BinaryExpression} -import org.apache.spark.sql.types.{AbstractDataType, DataType, ArrayType, DoubleType, KllDoublesSketchType} -import org.apache.spark.sql.catalyst.expressions.NullIntolerant -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode} import org.apache.datasketches.quantilescommon.QuantileSearchCriteria +import org.apache.spark.sql.types.KllDoublesSketchType + +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BooleanType, DataType, DoubleType} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpectsInputTypes, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.expressions.{UnaryExpression, TernaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Literal, NullIntolerant, RuntimeReplaceable} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.catalyst.expressions.ExpressionDescription -import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes +import org.apache.spark.sql.catalyst.trees.TernaryLike @ExpressionDescription( usage = """ @@ -128,13 +130,73 @@ case class KllGetMax(child: Expression) } } +@ExpressionDescription( + usage = """ + _FUNC_(expr, expr, isInclusive) - Returns an approximation to the PMF + of the given KllDoublesSketch using the specified search criteria (default: inclusive, isInclusive = true) + or exclusive using the given split points. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg(col), array(1.5, 3.5)) FROM VALUES (1.0), (2.0), (3.0) tab(col); + [0.3333333333333333, 0.6666666666666666, 0.0] + """ +) +case class KllGetPmf(first: Expression, + second: Expression, + third: Expression) + extends RuntimeReplaceable + with ImplicitCastInputTypes + with TernaryLike[Expression] { + + def this(first: Expression, second: Expression) = { + this(first, second, Literal(true)) + } + + override lazy val replacement: Expression = KllGetPmfCdf(first, second, third, true) + override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, ArrayType(DoubleType), BooleanType) + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = { + copy(first = newFirst, second = newSecond, third = newThird) + } +} + +@ExpressionDescription( + usage = """ + _FUNC_(expr, expr, isInclusive) - Returns an approximation to the PMF + of the given KllDoublesSketch using the specified search criteria (default: inclusive, isInclusive = true) + or exclusive using the given split points. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg(col), array(1.5, 3.5)) FROM VALUES (1.0), (2.0), (3.0) tab(col); + [0.3333333333333333, 0.6666666666666666, 0.0] + """ +) +case class KllGetCdf(first: Expression, + second: Expression, + third: Expression) + extends RuntimeReplaceable + with ImplicitCastInputTypes + with TernaryLike[Expression] { + + def this(first: Expression, second: Expression) = { + this(first, second, Literal(true)) + } + + override lazy val replacement: Expression = KllGetPmfCdf(first, second, third, false) + override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, ArrayType(DoubleType), BooleanType) + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = { + copy(first = newFirst, second = newSecond, third = newThird) + } +} + /** * Returns the PMF and CDF of the given quantile search criteria. * - * @param left A KllDoublesSketch sketch, in serialized form - * @param right An array of split points, as doubles - * @param isInclusive If true, use INCLUSIVE else EXCLUSIVE + * @param first A KllDoublesSketch sketch, in serialized form + * @param second An array of split points, as doubles + * @param third A boolean flag for inclusive mode. If true, use INCLUSIVE else EXCLUSIVE * @param isPmf Whether to return the PMF (true) or CDF (false) */ @ExpressionDescription( @@ -149,29 +211,32 @@ case class KllGetMax(child: Expression) [0.3333333333333333, 0.6666666666666666, 0.0] """ ) -case class KllGetPmfCdf(left: Expression, - right: Expression, - isInclusive: Boolean = true, +case class KllGetPmfCdf(first: Expression, + second: Expression, + third: Expression, isPmf: Boolean = false) - extends BinaryExpression + extends TernaryExpression with ExpectsInputTypes with NullIntolerant with ImplicitCastInputTypes { - override protected def withNewChildrenInternal(newLeft: Expression, - newRight: Expression) = { - copy(left = newLeft, right = newRight, isInclusive = isInclusive, isPmf = isPmf) + lazy val isInclusive = third.eval().asInstanceOf[Boolean] + + override protected def withNewChildrenInternal(newFirst: Expression, + newSecond: Expression, + newThird: Expression) = { + copy(first = newFirst, second = newSecond, third = newThird, isPmf = isPmf) } override def prettyName: String = "kll_get_pmf_cdf" - override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, ArrayType(DoubleType)) + override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, ArrayType(DoubleType), BooleanType) override def dataType: DataType = ArrayType(DoubleType, containsNull = false) - override def nullSafeEval(leftInput: Any, rightInput: Any): Any = { - val sketchBytes = leftInput.asInstanceOf[Array[Byte]] - val splitPoints = rightInput.asInstanceOf[GenericArrayData].toDoubleArray + override def nullSafeEval(firstInput: Any, secondInput: Any, thirdInput: Any): Any = { + val sketchBytes = firstInput.asInstanceOf[Array[Byte]] + val splitPoints = secondInput.asInstanceOf[GenericArrayData].toDoubleArray val sketch = KllDoublesSketch.wrap(Memory.wrap(sketchBytes)) val result: Array[Double] = @@ -183,30 +248,30 @@ case class KllGetPmfCdf(left: Expression, new GenericArrayData(result) } - override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): ExprCode = { - val sketchEval = left.genCode(ctx) + override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): ExprCode = { + val sketchEval = first.genCode(ctx) val sketch = ctx.freshName("sketch") - val splitPointsEval = right.genCode(ctx) + val splitPointsEval = second.genCode(ctx) val code = s""" |${sketchEval.code} |${splitPointsEval.code} |if (${sketchEval.isNull} || ${splitPointsEval.isNull}) { - | ${ev.isNull} = true; + | boolean ${ev.isNull} = true; |} else { - | QuantileSearchCriteria searchCriteria = ${if (isInclusive) "QuantileSearchCriteria.INCLUSIVE" else "QuantileSearchCriteria.EXCLUSIVE"}; + | org.apache.datasketches.quantilescommon.QuantileSearchCriteria searchCriteria = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"}; | final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value}); | final double[] splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray(); | final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"}; - | GenericArrayData ${ev.value} = new GenericArrayData(result); - | ${ev.isNull} = false; + | org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData(result); + | boolean ${ev.isNull} = false; |} """.stripMargin ev.copy(code = CodeBlock(Seq(code), Seq.empty)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arg1, arg2) => s"($arg1, $arg2)") + nullSafeCodeGen(ctx, ev, (arg1, arg2, arg3) => s"($arg1, $arg2, $arg3)") } } diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala b/src/main/scala/org/apache/spark/sql/functions_ds.scala index 7c2a96e..7fa9e7f 100644 --- a/src/main/scala/org/apache/spark/sql/functions_ds.scala +++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg} import org.apache.spark.sql.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.types.ArrayType -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.{ArrayType, BooleanType, DoubleType} // this class defines and maps all the variants of each function invocation, analagous // to the functions object in org.apache.spark.sql.functions @@ -87,11 +86,11 @@ object functions_ds { // get PMF def kll_get_pmf(sketch: Column, splitPoints: Column, isInclusive: Boolean): Column = withExpr { - new KllGetPmfCdf(sketch.expr, splitPoints.expr, isInclusive, true) + new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal.create(isInclusive, BooleanType), true) } def kll_get_pmf(sketch: Column, splitPoints: Column): Column = withExpr { - new KllGetPmfCdf(sketch.expr, splitPoints.expr, true, true) + new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal(true), true) } def kll_get_pmf(columnName: String, splitPoints: Column, isInclusive: Boolean): Column = { @@ -121,11 +120,11 @@ object functions_ds { // get CDF def kll_get_cdf(sketch: Column, splitPoints: Column, isInclusive: Boolean): Column = withExpr { - new KllGetPmfCdf(sketch.expr, splitPoints.expr, isInclusive, false) + new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal.create(isInclusive, BooleanType), false) } def kll_get_cdf(sketch: Column, splitPoints: Column): Column = withExpr { - new KllGetPmfCdf(sketch.expr, splitPoints.expr, true, false) + new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal(true), false) } def kll_get_cdf(columnName: String, splitPoints: Column, isInclusive: Boolean): Column = { diff --git a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala index 381aa49..e5feb96 100644 --- a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala +++ b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala @@ -28,13 +28,20 @@ import scala.reflect.ClassTag // DataSketches imports import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg} import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax} -import org.apache.spark.sql.expressions.KllGetPmfCdf +import org.apache.spark.sql.expressions.{KllGetPmf, KllGetCdf} // based on org.apache.spark.sql.catalyst.FunctionRegistry trait DatasketchesFunctionRegistry { // override this to define the actual functions val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] + // registers all the functions in the expressions Map + def registerFunctions(spark: SparkSession): Unit = { + expressions.foreach { case (name, (info, builder)) => + spark.sessionState.functionRegistry.registerFunction(FunctionIdentifier(name), info, builder) + } + } + // simplifies defining the expression (ignoring "since" as a stand-alone library) protected def expression[T <: Expression : ClassTag](name: String): (String, (ExpressionInfo, FunctionBuilder)) = { val (expressionInfo, builder) = FunctionRegistryBase.build[T](name, None) @@ -44,6 +51,12 @@ trait DatasketchesFunctionRegistry { // some functions throw a query compile-time exception around the wrong // number of parameters when using expression(). This function allows // explicit argument handling by providing a lambda to use for the bulder. + // This seems to be related to non-Expression inputs to the classes, but keeping + // this an an example of usage for now in case it really is needed: + // complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] => + // val isInclusive = if (args.length > 2) args(2).eval().asInstanceOf[Boolean] else true + // new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = false) + // } protected def complexExpression[T <: Expression : ClassTag](name: String)(f: (Seq[Expression]) => T): (String, (ExpressionInfo, FunctionBuilder)) = { val expressionInfo = FunctionRegistryBase.expressionInfo[T](name, None) val builder: FunctionBuilder = (args: Seq[Expression]) => f(args) @@ -58,30 +71,7 @@ object DatasketchesFunctionRegistry extends DatasketchesFunctionRegistry { expression[KllDoublesMergeAgg]("kll_merge_agg"), expression[KllGetMin]("kll_get_min"), expression[KllGetMax]("kll_get_max"), - - // TODO: it seems like there's got to be a way to simplify this, but - // perhaps not with the optional isInclusive parameter? - // Spark uses ExpressionBuilder, extending that class via a builder class - // and overriding build() to handle the lambda. - // It allows for a cleaner registry here, so we can look at where to put - // the builder classes in the future. - // See org.apache.spark.sql.catalyst.expressions.variant.variantExpressions.scala - complexExpression[KllGetPmfCdf]("kll_get_pmf") { args: Seq[Expression] => - val isInclusive = if (args.length > 2) args(2).eval().asInstanceOf[Boolean] else true - new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = true) - }, - complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] => - val isInclusive = if (args.length > 2) args(2).eval().asInstanceOf[Boolean] else true - new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = false) - } + expression[KllGetPmf]("kll_get_pmf"), + expression[KllGetCdf]("kll_get_cdf") ) - - // registers all the functions in the expressions Map - def registerFunctions(spark: SparkSession): Unit = { - val functionRegistry = spark.sessionState.functionRegistry - expressions.foreach { case (name, (info, builder)) => - functionRegistry.registerFunction(FunctionIdentifier(name), info, builder) - } - } - } diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala b/src/test/scala/org/apache/spark/sql/KllTest.scala index 570d72b..f9c835d 100644 --- a/src/test/scala/org/apache/spark/sql/KllTest.scala +++ b/src/test/scala/org/apache/spark/sql/KllTest.scala @@ -58,7 +58,7 @@ class KllTest extends SparkSessionManager { )) val df = spark.createDataFrame(dataList, schema) - df.show() + assert(df.count() == numClass) } test("Create DataFrame from parallelize()") { @@ -80,7 +80,7 @@ class KllTest extends SparkSessionManager { val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) .select($"id", KllDoublesSketchType.wrapBytes($"kll").as("sketch")) - df.show() + assert(df.count() == numClass) } test("KLL Doubles Sketch via scala") { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
