This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch simplify_registration
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit 455dcde210799a09b8f5df8dc5678d7f3a82376d
Author: Jon Malkin <[email protected]>
AuthorDate: Tue Jan 7 13:19:50 2025 -0800

    Two unrelated changes. Add type-checking to k in sketch creation, and 
modify pmf/cdf to allow simpler registration for SQL
---
 build.sbt                                          |   4 -
 .../apache/spark/sql/aggregate/KllAggregate.scala  |  33 ++++--
 .../spark/sql/expressions/KllExpressions.scala     | 121 ++++++++++++++++-----
 .../scala/org/apache/spark/sql/functions_ds.scala  |  11 +-
 .../registrar/DatasketchesFunctionRegistry.scala   |  42 +++----
 src/test/scala/org/apache/spark/sql/KllTest.scala  |   4 +-
 6 files changed, 141 insertions(+), 74 deletions(-)

diff --git a/build.sbt b/build.sbt
index daa4c63..eb93702 100644
--- a/build.sbt
+++ b/build.sbt
@@ -49,9 +49,5 @@ scalacOptions ++= Seq(
 
 Test / logBuffered := false
 
-// Only show warnings and errors on the screen for compilations.
-// This applies to both test:compile and compile and is Info by default
-Compile / logLevel := Level.Warn
-
 // Level.INFO is needed to see detailed output when running tests
 Test / logLevel := Level.Info
diff --git a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala 
b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
index 3b7a506..c77c7ad 100644
--- a/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
+++ b/src/main/scala/org/apache/spark/sql/aggregate/KllAggregate.scala
@@ -24,14 +24,15 @@ import 
org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, 
LongType, NumericType, FloatType, DoubleType, KllDoublesSketchType}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 
 /**
  * The KllDoublesSketchAgg function utilizes a Datasketches KllDoublesSketch 
instance
  * to create a sketch from a column of values which can be used to estimate 
quantiles
  * and histograms.
  *
- * @param child child expression against which the sketch will be created
- * @param k the size-accuracy trade-off parameter for the sketch
+ * @param left Expression against which the sketch will be created
+ * @param right k, the size-accuracy trade-off parameter for the sketch, int 
in range [1, 65535]
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -58,6 +59,7 @@ case class KllDoublesSketchAgg(
     right.eval() match {
       case null => KllSketch.DEFAULT_K
       case k: Int => k
+      // this shouldn't happen after checkInputDataTypes()
       case _ => throw new SparkUnsupportedOperationException(
         s"Unsupported input type ${right.dataType.catalogString}",
         Map("dataType" -> dataType.toString))
@@ -65,7 +67,6 @@ case class KllDoublesSketchAgg(
   }
 
   // Constructors
-
   def this(child: Expression) = {
     this(child, Literal(KllSketch.DEFAULT_K), 0, 0)
   }
@@ -91,18 +92,36 @@ case class KllDoublesSketchAgg(
   }
 
   // overrides for TypedImperativeAggregate
+  override lazy val deterministic: Boolean = false
+
   override def prettyName: String = "kll_sketch_agg"
 
   override def dataType: DataType = KllDoublesSketchType
 
   override def nullable: Boolean = false
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType, LongType, FloatType, DoubleType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // k must be a constant
+    if (!right.foldable) {
+      return TypeCheckResult.TypeCheckFailure(s"k must be foldable, but got: 
${right}")
+    }
+    // Check if k > 0
+    right.eval() match {
+      case k: Int if k > 0 => // valid state, do nothing
+      case k: Int if k > KllSketch.MAX_K => return 
TypeCheckResult.TypeCheckFailure(
+        s"k must be less than or equal to ${KllSketch.MAX_K}, but got: $k")
+      case k: Int => return TypeCheckResult.TypeCheckFailure(s"k must be 
greater than 0, but got: $k")
+      case _ => return TypeCheckResult.TypeCheckFailure(s"Unsupported input 
type ${right.dataType.catalogString}")
+    }
+
+    // additional validations of k handled in the DataSketches library
+    TypeCheckResult.TypeCheckSuccess
+  }
 
-  // create buffer
   override def createAggregationBuffer(): KllDoublesSketch = 
KllDoublesSketch.newHeapInstance(k)
 
-  // update
   override def update(sketch: KllDoublesSketch, input: InternalRow): 
KllDoublesSketch = {
     val value = left.eval(input)
     if (value != null) {
@@ -119,7 +138,6 @@ case class KllDoublesSketchAgg(
     sketch
   }
 
-  // union (merge)
   override def merge(sketch: KllDoublesSketch, other: KllDoublesSketch): 
KllDoublesSketch = {
     if (other != null && !other.isEmpty) {
       sketch.merge(other)
@@ -127,7 +145,6 @@ case class KllDoublesSketchAgg(
     sketch
   }
 
-  // eval
   override def eval(sketch: KllDoublesSketch): Any = {
     if (sketch == null || sketch.isEmpty) {
       null
diff --git 
a/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala 
b/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala
index 745261b..34b6296 100644
--- a/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala
+++ b/src/main/scala/org/apache/spark/sql/expressions/KllExpressions.scala
@@ -19,14 +19,16 @@ package org.apache.spark.sql.expressions
 
 import org.apache.datasketches.memory.Memory
 import org.apache.datasketches.kll.KllDoublesSketch
-import org.apache.spark.sql.catalyst.expressions.{Expression, 
ExpectsInputTypes, UnaryExpression, BinaryExpression}
-import org.apache.spark.sql.types.{AbstractDataType, DataType, ArrayType, 
DoubleType, KllDoublesSketchType}
-import org.apache.spark.sql.catalyst.expressions.NullIntolerant
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, 
CodegenContext, ExprCode}
 import org.apache.datasketches.quantilescommon.QuantileSearchCriteria
+import org.apache.spark.sql.types.KllDoublesSketchType
+
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BooleanType, 
DataType, DoubleType}
+import org.apache.spark.sql.catalyst.expressions.{Expression, 
ExpressionDescription, ExpectsInputTypes, ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.{UnaryExpression, 
TernaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Literal, NullIntolerant, 
RuntimeReplaceable}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, 
CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
-import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
+import org.apache.spark.sql.catalyst.trees.TernaryLike
 
 @ExpressionDescription(
   usage = """
@@ -128,13 +130,73 @@ case class KllGetMax(child: Expression)
   }
 }
 
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, expr, isInclusive) - Returns an approximation to the PMF
+      of the given KllDoublesSketch using the specified search criteria 
(default: inclusive, isInclusive = true)
+      or exclusive using the given split points.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg(col), array(1.5, 3.5)) FROM VALUES (1.0), 
(2.0), (3.0) tab(col);
+       [0.3333333333333333, 0.6666666666666666, 0.0]
+  """
+)
+case class KllGetPmf(first: Expression,
+                     second: Expression,
+                     third: Expression)
+    extends RuntimeReplaceable
+    with ImplicitCastInputTypes
+    with TernaryLike[Expression] {
+
+    def this(first: Expression, second: Expression) = {
+        this(first, second, Literal(true))
+    }
+
+    override lazy val replacement: Expression = KllGetPmfCdf(first, second, 
third, true)
+    override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, 
ArrayType(DoubleType), BooleanType)
+    override protected def withNewChildrenInternal(newFirst: Expression, 
newSecond: Expression, newThird: Expression): Expression = {
+        copy(first = newFirst, second = newSecond, third = newThird)
+    }
+}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, expr, isInclusive) - Returns an approximation to the PMF
+      of the given KllDoublesSketch using the specified search criteria 
(default: inclusive, isInclusive = true)
+      or exclusive using the given split points.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg(col), array(1.5, 3.5)) FROM VALUES (1.0), 
(2.0), (3.0) tab(col);
+       [0.3333333333333333, 0.6666666666666666, 0.0]
+  """
+)
+case class KllGetCdf(first: Expression,
+                     second: Expression,
+                     third: Expression)
+    extends RuntimeReplaceable
+    with ImplicitCastInputTypes
+    with TernaryLike[Expression] {
+
+    def this(first: Expression, second: Expression) = {
+        this(first, second, Literal(true))
+    }
+
+    override lazy val replacement: Expression = KllGetPmfCdf(first, second, 
third, false)
+    override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, 
ArrayType(DoubleType), BooleanType)
+    override protected def withNewChildrenInternal(newFirst: Expression, 
newSecond: Expression, newThird: Expression): Expression = {
+        copy(first = newFirst, second = newSecond, third = newThird)
+    }
+}
+
 
 /**
   * Returns the PMF and CDF of the given quantile search criteria.
   *
-  * @param left A KllDoublesSketch sketch, in serialized form
-  * @param right An array of split points, as doubles
-  * @param isInclusive If true, use INCLUSIVE else EXCLUSIVE
+  * @param first A KllDoublesSketch sketch, in serialized form
+  * @param second An array of split points, as doubles
+  * @param third A boolean flag for inclusive mode. If true, use INCLUSIVE 
else EXCLUSIVE
   * @param isPmf Whether to return the PMF (true) or CDF (false)
   */
 @ExpressionDescription(
@@ -149,29 +211,32 @@ case class KllGetMax(child: Expression)
        [0.3333333333333333, 0.6666666666666666, 0.0]
   """
 )
-case class KllGetPmfCdf(left: Expression,
-                        right: Expression,
-                        isInclusive: Boolean = true,
+case class KllGetPmfCdf(first: Expression,
+                        second: Expression,
+                        third: Expression,
                         isPmf: Boolean = false)
- extends BinaryExpression
+ extends TernaryExpression
  with ExpectsInputTypes
  with NullIntolerant
  with ImplicitCastInputTypes {
 
-  override protected def withNewChildrenInternal(newLeft: Expression,
-                                              newRight: Expression) = {
-    copy(left = newLeft, right = newRight, isInclusive = isInclusive, isPmf = 
isPmf)
+  lazy val isInclusive = third.eval().asInstanceOf[Boolean]
+
+  override protected def withNewChildrenInternal(newFirst: Expression,
+                                                 newSecond: Expression,
+                                                 newThird: Expression) = {
+    copy(first = newFirst, second = newSecond, third = newThird, isPmf = isPmf)
   }
 
   override def prettyName: String = "kll_get_pmf_cdf"
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, 
ArrayType(DoubleType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(KllDoublesSketchType, 
ArrayType(DoubleType), BooleanType)
 
   override def dataType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def nullSafeEval(leftInput: Any, rightInput: Any): Any = {
-    val sketchBytes = leftInput.asInstanceOf[Array[Byte]]
-    val splitPoints = rightInput.asInstanceOf[GenericArrayData].toDoubleArray
+  override def nullSafeEval(firstInput: Any, secondInput: Any, thirdInput: 
Any): Any = {
+    val sketchBytes = firstInput.asInstanceOf[Array[Byte]]
+    val splitPoints = secondInput.asInstanceOf[GenericArrayData].toDoubleArray
     val sketch = KllDoublesSketch.wrap(Memory.wrap(sketchBytes))
 
     val result: Array[Double] =
@@ -183,30 +248,30 @@ case class KllGetPmfCdf(left: Expression,
     new GenericArrayData(result)
   }
 
-  override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: 
(String, String) => String): ExprCode = {
-    val sketchEval = left.genCode(ctx)
+  override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: 
(String, String, String) => String): ExprCode = {
+    val sketchEval = first.genCode(ctx)
     val sketch = ctx.freshName("sketch")
-    val splitPointsEval = right.genCode(ctx)
+    val splitPointsEval = second.genCode(ctx)
     val code =
       s"""
          |${sketchEval.code}
          |${splitPointsEval.code}
          |if (${sketchEval.isNull} || ${splitPointsEval.isNull}) {
-         |  ${ev.isNull} = true;
+         |  boolean ${ev.isNull} = true;
          |} else {
-         |  QuantileSearchCriteria searchCriteria = ${if (isInclusive) 
"QuantileSearchCriteria.INCLUSIVE" else "QuantileSearchCriteria.EXCLUSIVE"};
+         |  org.apache.datasketches.quantilescommon.QuantileSearchCriteria 
searchCriteria = ${if (isInclusive) 
"org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else 
"org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"};
          |  final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value});
          |  final double[] splitPoints = 
((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray();
          |  final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, 
searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"};
-         |  GenericArrayData ${ev.value} = new GenericArrayData(result);
-         |  ${ev.isNull} = false;
+         |  org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = 
new org.apache.spark.sql.catalyst.util.GenericArrayData(result);
+         |  boolean ${ev.isNull} = false;
          |}
        """.stripMargin
     ev.copy(code = CodeBlock(Seq(code), Seq.empty))
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (arg1, arg2) => s"($arg1, $arg2)")
+    nullSafeCodeGen(ctx, ev, (arg1, arg2, arg3) => s"($arg1, $arg2, $arg3)")
   }
 }
 
diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala 
b/src/main/scala/org/apache/spark/sql/functions_ds.scala
index 7c2a96e..7fa9e7f 100644
--- a/src/main/scala/org/apache/spark/sql/functions_ds.scala
+++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
 import org.apache.spark.sql.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.types.ArrayType
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{ArrayType, BooleanType, DoubleType}
 
 // this class defines and maps all the variants of each function invocation, 
analogous
 // to the functions object in org.apache.spark.sql.functions
@@ -87,11 +86,11 @@ object functions_ds {
 
   // get PMF
   def kll_get_pmf(sketch: Column, splitPoints: Column, isInclusive: Boolean): 
Column = withExpr {
-    new KllGetPmfCdf(sketch.expr, splitPoints.expr, isInclusive, true)
+    new KllGetPmfCdf(sketch.expr, splitPoints.expr, 
Literal.create(isInclusive, BooleanType), true)
   }
 
   def kll_get_pmf(sketch: Column, splitPoints: Column): Column = withExpr {
-    new KllGetPmfCdf(sketch.expr, splitPoints.expr, true, true)
+    new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal(true), true)
   }
 
   def kll_get_pmf(columnName: String, splitPoints: Column, isInclusive: 
Boolean): Column = {
@@ -121,11 +120,11 @@ object functions_ds {
 
   // get CDF
   def kll_get_cdf(sketch: Column, splitPoints: Column, isInclusive: Boolean): 
Column = withExpr {
-    new KllGetPmfCdf(sketch.expr, splitPoints.expr, isInclusive, false)
+    new KllGetPmfCdf(sketch.expr, splitPoints.expr, 
Literal.create(isInclusive, BooleanType), false)
   }
 
   def kll_get_cdf(sketch: Column, splitPoints: Column): Column = withExpr {
-    new KllGetPmfCdf(sketch.expr, splitPoints.expr, true, false)
+    new KllGetPmfCdf(sketch.expr, splitPoints.expr, Literal(true), false)
   }
 
   def kll_get_cdf(columnName: String, splitPoints: Column, isInclusive: 
Boolean): Column = {
diff --git 
a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
 
b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
index 381aa49..e5feb96 100644
--- 
a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
+++ 
b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
@@ -28,13 +28,20 @@ import scala.reflect.ClassTag
 // DataSketches imports
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
 import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax}
-import org.apache.spark.sql.expressions.KllGetPmfCdf
+import org.apache.spark.sql.expressions.{KllGetPmf, KllGetCdf}
 
 // based on org.apache.spark.sql.catalyst.FunctionRegistry
 trait DatasketchesFunctionRegistry {
   // override this to define the actual functions
   val expressions: Map[String, (ExpressionInfo, FunctionBuilder)]
 
+  // registers all the functions in the expressions Map
+  def registerFunctions(spark: SparkSession): Unit = {
+    expressions.foreach { case (name, (info, builder)) =>
+      
spark.sessionState.functionRegistry.registerFunction(FunctionIdentifier(name), 
info, builder)
+    }
+  }
+
   // simplifies defining the expression (ignoring "since" as a stand-alone 
library)
   protected def expression[T <: Expression : ClassTag](name: String): (String, 
(ExpressionInfo, FunctionBuilder)) = {
     val (expressionInfo, builder) = FunctionRegistryBase.build[T](name, None)
@@ -44,6 +51,12 @@ trait DatasketchesFunctionRegistry {
   // some functions throw a query compile-time exception around the wrong
   // number of parameters when using expression(). This function allows
 // explicit argument handling by providing a lambda to use for the builder.
+  // This seems to be related to non-Expression inputs to the classes, but 
keeping
+  // this as an example of usage for now in case it really is needed:
+  //    complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] 
=>
+  //      val isInclusive = if (args.length > 2) 
args(2).eval().asInstanceOf[Boolean] else true
+  //      new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf 
= false)
+  //    }
   protected def complexExpression[T <: Expression : ClassTag](name: String)(f: 
(Seq[Expression]) => T): (String, (ExpressionInfo, FunctionBuilder)) = {
     val expressionInfo = FunctionRegistryBase.expressionInfo[T](name, None)
     val builder: FunctionBuilder = (args: Seq[Expression]) => f(args)
@@ -58,30 +71,7 @@ object DatasketchesFunctionRegistry extends 
DatasketchesFunctionRegistry {
     expression[KllDoublesMergeAgg]("kll_merge_agg"),
     expression[KllGetMin]("kll_get_min"),
     expression[KllGetMax]("kll_get_max"),
-
-    // TODO: it seems like there's got to be a way to simplify this, but
-    // perhaps not with the optional isInclusive parameter?
-    // Spark uses ExpressionBuilder, extending that class via a builder class
-    // and overriding build() to handle the lambda.
-    // It allows for a cleaner registry here, so we can look at where to put
-    // the builder classes in the future.
-    // See 
org.apache.spark.sql.catalyst.expressions.variant.variantExpressions.scala
-    complexExpression[KllGetPmfCdf]("kll_get_pmf") { args: Seq[Expression] =>
-      val isInclusive = if (args.length > 2) 
args(2).eval().asInstanceOf[Boolean] else true
-      new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = 
true)
-    },
-    complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] =>
-      val isInclusive = if (args.length > 2) 
args(2).eval().asInstanceOf[Boolean] else true
-      new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = 
false)
-    }
+    expression[KllGetPmf]("kll_get_pmf"),
+    expression[KllGetCdf]("kll_get_cdf")
   )
-
-  // registers all the functions in the expressions Map
-  def registerFunctions(spark: SparkSession): Unit = {
-    val functionRegistry = spark.sessionState.functionRegistry
-    expressions.foreach { case (name, (info, builder)) =>
-      functionRegistry.registerFunction(FunctionIdentifier(name), info, 
builder)
-    }
-  }
-
 }
diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala 
b/src/test/scala/org/apache/spark/sql/KllTest.scala
index 570d72b..f9c835d 100644
--- a/src/test/scala/org/apache/spark/sql/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -58,7 +58,7 @@ class KllTest extends SparkSessionManager {
     ))
 
     val df = spark.createDataFrame(dataList, schema)
-    df.show()
+    assert(df.count() == numClass)
   }
 
   test("Create DataFrame from parallelize()") {
@@ -80,7 +80,7 @@ class KllTest extends SparkSessionManager {
     val df = spark.createDataFrame(spark.sparkContext.parallelize(data), 
schema)
       .select($"id", KllDoublesSketchType.wrapBytes($"kll").as("sketch"))
 
-    df.show()
+    assert(df.count() == numClass)
   }
 
   test("KLL Doubles Sketch via scala") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to