This is an automated email from the ASF dual-hosted git repository.
alsay pushed a commit to branch theta_params
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
The following commit(s) were added to refs/heads/theta_params by this push:
new ffe0066 theta functions with parameters
ffe0066 is described below
commit ffe00665819ea1a92cd4f0ddd11cacc9672c9fce
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Mon Feb 24 20:30:28 2025 -0800
theta functions with parameters
---
.../datasketches/theta/ThetaFunctionRegistry.scala | 5 +-
...onRegistry.scala => ThetaSketchConstants.scala} | 18 +----
.../theta/aggregate/ThetaSketchAggBuild.scala | 39 +++++-----
.../theta/aggregate/ThetaSketchAggUnion.scala | 88 ++++++++++++++--------
.../theta/expressions/ThetaExpressions.scala | 73 ++++++++++++++++--
.../spark/sql/datasketches/theta/functions.scala | 28 +++++--
.../spark/sql/datasketches/theta/ThetaTest.scala | 65 ++++++++++++++--
7 files changed, 230 insertions(+), 86 deletions(-)
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
index ddbe615..3ceeb3c 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
@@ -22,13 +22,14 @@ import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo}
import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, ThetaSketchAggUnion}
-import org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
+import org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstimate, ThetaSketchToString}
import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
object ThetaFunctionRegistry extends DatasketchesFunctionRegistry {
override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
expression[ThetaSketchAggBuild]("theta_sketch_agg_build"),
expression[ThetaSketchAggUnion]("theta_sketch_agg_union"),
- expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate")
+ expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate"),
+ expression[ThetaSketchToString]("theta_sketch_to_string")
)
}
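
For orientation, a minimal sketch of how the newly registered function is reached from SQL once the registry is installed (assumes an active SparkSession named `spark`; registerFunctions is the entry point the tests below use):

    // Register the Theta functions, then call the new theta_sketch_to_string
    // alongside the existing build aggregate.
    ThetaFunctionRegistry.registerFunctions(spark)
    spark.sql(
      "SELECT theta_sketch_to_string(theta_sketch_agg_build(col)) AS summary " +
        "FROM VALUES (1), (2), (3) tab(col)"
    ).show(truncate = false)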
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
similarity index 50%
copy from src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
copy to src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
index ddbe615..1e8befc 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaFunctionRegistry.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/ThetaSketchConstants.scala
@@ -17,18 +17,6 @@
package org.apache.spark.sql.datasketches.theta
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo}
-
-import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
-import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, ThetaSketchAggUnion}
-import org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
-import org.apache.spark.sql.datasketches.common.DatasketchesFunctionRegistry
-
-object ThetaFunctionRegistry extends DatasketchesFunctionRegistry {
- override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
- expression[ThetaSketchAggBuild]("theta_sketch_agg_build"),
- expression[ThetaSketchAggUnion]("theta_sketch_agg_union"),
- expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate")
- )
-}
+object ThetaSketchConstants {
+ final val DEFAULT_LG_K: Int = 12
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
index 70d13e2..f6eaba5 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggBuild.scala
@@ -17,21 +17,19 @@
package org.apache.spark.sql.datasketches.theta.aggregate
-import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
-import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSketchWrapper}
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType, StringType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
-import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
-import org.apache.datasketches.common.ResizeFactor
+import org.apache.spark.sql.datasketches.theta.ThetaSketchConstants.DEFAULT_LG_K
+import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSketchWrapper}
-object ThetaSketchConstants {
- final val DEFAULT_LG_K: Int = 12
-}
+import org.apache.datasketches.common.ResizeFactor
+import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
+import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
/**
* The ThetaSketchBuild function creates a Theta sketch from a column of values
@@ -39,17 +37,21 @@ object ThetaSketchConstants {
*
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
*
- * @param child child expression, from which to build a sketch
- * @param lgk the size-accraucy trade-off parameter for the sketch
+ * @param input expression, from which to build a sketch
+ * @param lgK size-accuracy trade-off parameter for the sketch
+ * @param seed update seed for the sketch
+ * @param p initial sampling probability for the sketch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(expr, lgk) - Creates a Theta Sketch and returns the binary representation.
- `lgk` (optional, default: 12) the size-accuracy trade-off parameter.""",
+ _FUNC_(expr, lgK, seed, p) - Creates a Theta Sketch and returns the binary representation.
+ `lgK` (optional, default: 12) size-accuracy trade-off parameter.
+ `seed` (optional, default: 9001) update seed for the sketch.
+ `p` (optional, default: 1) initial sampling probability for the sketch.""",
examples = """
Example:
- > SELECT theta_sketch_get_estimate(_FUNC_(col, 12)) FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ > SELECT theta_sketch_get_estimate(_FUNC_(col)) FROM VALUES (1), (2), (3), (4), (5) tab(col);
5.0
""",
)
@@ -67,7 +69,7 @@ case class ThetaSketchAggBuild(
lazy val lgK: Int = {
lgKExpr.eval() match {
- case null => ThetaSketchConstants.DEFAULT_LG_K
+ case null => DEFAULT_LG_K
case lgK: Int => lgK
case _ => throw new IllegalArgumentException(
s"Unsupported input type ${lgKExpr.dataType.catalogString}")
@@ -99,7 +101,7 @@ case class ThetaSketchAggBuild(
def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression, pExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, pExpr, 0, 0)
def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, Literal(1f))
def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, lgKExpr, Literal(DEFAULT_UPDATE_SEED))
- def this(inputExpr: Expression) = this(inputExpr, Literal(ThetaSketchConstants.DEFAULT_LG_K))
+ def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, Literal(lgK), Literal(seed))
@@ -115,13 +117,13 @@ case class ThetaSketchAggBuild(
copy(inputExpr = newFirst, lgKExpr = newSecond, seedExpr = newThird, pExpr = newFourth)
}
- override def prettyName: String = "theta_sketch_build"
+ override def prettyName: String = "theta_sketch_agg_build"
override def dataType: DataType = ThetaSketchType
override def nullable: Boolean = false
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType, LongType, FloatType, DoubleType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, StringType), IntegerType, LongType, FloatType)
override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(updateSketch
  = Some(UpdateSketch.builder().setLogNominalEntries(lgK).setSeed(seed).setP(p).build()))
@@ -134,6 +136,7 @@ case class ThetaSketchAggBuild(
case FloatType => wrapper.updateSketch.get.update(value.asInstanceOf[Float])
case IntegerType => wrapper.updateSketch.get.update(value.asInstanceOf[Int])
case LongType => wrapper.updateSketch.get.update(value.asInstanceOf[Long])
+ case StringType => wrapper.updateSketch.get.update(value.asInstanceOf[UTF8String].toString)
case _ => throw new IllegalArgumentException(
s"Unsupported input type ${inputExpr.dataType.catalogString}")
}
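
The build aggregate's three optional parameters can be exercised from SQL as below; a minimal sketch, assuming the functions are registered as above. The values 14, 111L and 0.5 are illustrative, not defaults, and `p` is cast to FLOAT explicitly since the aggregate declares FloatType and ExpectsInputTypes performs no implicit casts:

    // Build with explicit lgK = 14, seed = 111L and sampling probability p = 0.5.
    spark.sql(
      """SELECT theta_sketch_get_estimate(
        |  theta_sketch_agg_build(col, 14, 111L, CAST(0.5 AS FLOAT))) AS estimate
        |FROM VALUES (1), (2), (3), (4), (5) tab(col)""".stripMargin
    ).show()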
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
index 6e067c9..fd8ce17 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/aggregate/ThetaSketchAggUnion.scala
@@ -17,28 +17,34 @@
package org.apache.spark.sql.datasketches.theta.aggregate
-import org.apache.datasketches.memory.Memory
-import org.apache.datasketches.theta.{Sketch, SetOperation}
-import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSketchWrapper}
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
+import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType}
+
+import org.apache.spark.sql.datasketches.theta.ThetaSketchConstants.DEFAULT_LG_K
+import org.apache.spark.sql.datasketches.theta.types.{ThetaSketchType, ThetaSketchWrapper}
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.{Sketch, SetOperation}
+import org.apache.datasketches.thetacommon.ThetaUtil.DEFAULT_UPDATE_SEED
/**
* Theta Union operation.
*
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
*
- * @param child child expression, on which to perform the union operation
- * @param lgk the size-accraucy trade-off parameter for the sketch
+ * @param input expression, on which to perform the union operation
+ * @param lgK size-accuracy trade-off parameter for the sketch
+ * @param seed update seed for the sketch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(expr, lgk) - Performs Theta Union operation and returns the result as Theta Sketch in binary form
+ _FUNC_(expr, lgK, seed) - Performs a Theta Union operation and returns the result as a Theta Sketch in binary form.
+ `lgK` (optional, default: 12) size-accuracy trade-off parameter.
+ `seed` (optional, default: 9001) update seed for the sketch.
""",
examples = """
Example:
@@ -48,27 +54,43 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
)
// scalastyle:on line.size.limit
case class ThetaSketchAggUnion(
- left: Expression,
- right: Expression,
+ inputExpr: Expression,
+ lgKExpr: Expression,
+ seedExpr: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[ThetaSketchWrapper]
- with BinaryLike[Expression]
+ with TernaryLike[Expression]
with ExpectsInputTypes {
- lazy val lgk: Int = {
- right.eval() match {
- case null => 12
- case lgk: Int => lgk
+ lazy val lgK: Int = {
+ lgKExpr.eval() match {
+ case null => DEFAULT_LG_K
+ case lgK: Int => lgK
case _ => throw new IllegalArgumentException(
- s"Unsupported input type ${right.dataType.catalogString}")
+ s"Unsupported input type ${lgKExpr.dataType.catalogString}")
}
}
- def this(left: Expression, right: Expression) = this(left, right, 0, 0)
- def this(left: Expression) = this(left, Literal(12), 0, 0)
-// def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
-// def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+ lazy val seed: Long = {
+ seedExpr.eval() match {
+ case null => DEFAULT_UPDATE_SEED
+ case seed: Long => seed
+ case _ => throw new IllegalArgumentException(
+ s"Unsupported input type ${seedExpr.dataType.catalogString}")
+ }
+ }
+
+ override def first: Expression = inputExpr
+ override def second: Expression = lgKExpr
+ override def third: Expression = seedExpr
+
+ def this(inputExpr: Expression, lgKExpr: Expression, seedExpr: Expression) = this(inputExpr, lgKExpr, seedExpr, 0, 0)
+ def this(inputExpr: Expression, lgKExpr: Expression) = this(inputExpr, lgKExpr, Literal(DEFAULT_UPDATE_SEED))
+ def this(inputExpr: Expression) = this(inputExpr, Literal(DEFAULT_LG_K))
+
+ def this(inputExpr: Expression, lgK: Int) = this(inputExpr, Literal(lgK))
+ def this(inputExpr: Expression, lgK: Int, seed: Long) = this(inputExpr, Literal(lgK), Literal(seed))
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAggUnion =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -76,37 +98,37 @@ case class ThetaSketchAggUnion(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAggUnion =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaSketchAggUnion = {
- copy(left = newLeft, right = newRight)
+ override protected def withNewChildrenInternal(newInputExpr: Expression, newLgKExpr: Expression, newSeedExpr: Expression): ThetaSketchAggUnion = {
+ copy(inputExpr = newInputExpr, lgKExpr = newLgKExpr, seedExpr = newSeedExpr)
}
// overrides for TypedImperativeAggregate
- override def prettyName: String = "theta_union"
+ override def prettyName: String = "theta_sketch_agg_union"
override def dataType: DataType = ThetaSketchType
override def nullable: Boolean = false
- // TODO: refine this?
- override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType, IntegerType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType, IntegerType, LongType)
- override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion))
+ override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(union
+ = Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).buildUnion()))
override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
- val bytes = left.eval(input)
+ val bytes = inputExpr.eval(input)
if (bytes != null) {
- left.dataType match {
+ inputExpr.dataType match {
case ThetaSketchType => wrapper.union.get.union(Sketch.wrap(Memory.wrap(bytes.asInstanceOf[Array[Byte]])))
case _ => throw new IllegalArgumentException(
- s"Unsupported input type ${left.dataType.catalogString}")
+ s"Unsupported input type ${inputExpr.dataType.catalogString}")
}
}
wrapper
}
override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
- if (other != null && !other.compactSketch.get.isEmpty) {
+ if (other != null && !other.compactSketch.get.isEmpty()) {
if (wrapper.union.isEmpty) {
- wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+ wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgK).setSeed(seed).buildUnion())
if (wrapper.compactSketch.isDefined) {
wrapper.union.get.union(wrapper.compactSketch.get)
wrapper.compactSketch = None
@@ -118,7 +140,7 @@ case class ThetaSketchAggUnion(
}
override def eval(wrapper: ThetaSketchWrapper): Any = {
- wrapper.union.get.getResult.toByteArrayCompressed
+ wrapper.union.get.getResult.toByteArrayCompressed()
}
override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
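
One caveat the new seed parameter introduces: the seed given to the union has to match the seed the input sketches were built with, since Theta sketches carry a seed hash that set operations validate. A minimal Scala sketch of a matched build/union pair (data, group and value are illustrative names):

    // Build per group and union across groups with the same lgK/seed pair.
    import org.apache.spark.sql.datasketches.theta.functions._
    val grouped = data.groupBy("group").agg(theta_sketch_agg_build("value", 14, 111L).as("sketch"))
    val merged = grouped.agg(theta_sketch_agg_union("sketch", 14, 111L).as("merged"))
    merged.select(theta_sketch_get_estimate("merged").as("estimate")).show()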
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
index c426ee4..c77fcd4 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/expressions/ThetaExpressions.scala
@@ -17,15 +17,17 @@
package org.apache.spark.sql.datasketches.theta.expressions
-import org.apache.datasketches.memory.Memory
-import org.apache.datasketches.theta.Sketch
-import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpectsInputTypes, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.NullIntolerant
import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.spark.sql.datasketches.theta.types.ThetaSketchType
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.Sketch
@ExpressionDescription(
usage = """
@@ -33,7 +35,7 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
""",
examples = """
Example:
- > SELECT _FUNC_(theta_sketch_build(col)) FROM VALUES (1), (2), (3) tab(col);
+ > SELECT _FUNC_(theta_sketch_agg_build(col)) FROM VALUES (1), (2), (3) tab(col);
3.0
"""
)
@@ -53,7 +55,7 @@ case class ThetaSketchGetEstimate(child: Expression)
override def dataType: DataType = DoubleType
override def nullSafeEval(input: Any): Any = {
- Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate
+ Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate()
}
override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = {
@@ -71,3 +73,60 @@ case class ThetaSketchGetEstimate(child: Expression)
nullSafeCodeGen(ctx, ev, c => s"($c)")
}
}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns a summary string that represents the state of the given sketch
+ """,
+ examples = """
+ Example:
+ > SELECT _FUNC_(theta_sketch_agg_build(col)) FROM VALUES (1), (2), (3) tab(col);
+ ### HeapCompactSketch SUMMARY:
+ Estimate : 3.0
+ Upper Bound, 95% conf : 3.0
+ Lower Bound, 95% conf : 3.0
+ Theta (double) : 1.0
+ Theta (long) : 9223372036854775807
+ Theta (long) hex : 7fffffffffffffff
+ EstMode? : false
+ Empty? : false
+ Ordered? : true
+ Retained Entries : 3
+ Seed Hash : 93cc | 37836
+ ### END SKETCH SUMMARY
+ """
+)
+case class ThetaSketchToString(child: Expression)
+ extends UnaryExpression
+ with ExpectsInputTypes
+ with NullIntolerant {
+
+ override protected def withNewChildInternal(newChild: Expression): ThetaSketchToString = {
+ copy(child = newChild)
+ }
+
+ override def prettyName: String = "theta_sketch_to_string"
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType)
+
+ override def dataType: DataType = StringType
+
+ override def nullSafeEval(input: Any): Any = {
+    UTF8String.fromString(Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).toString())
+ }
+
+ override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = {
+ val childEval = child.genCode(ctx)
+ val sketch = ctx.freshName("sketch")
+ val code = s"""
+ ${childEval.code}
+      final org.apache.datasketches.theta.Sketch $sketch = org.apache.spark.sql.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
+      final org.apache.spark.unsafe.types.UTF8String ${ev.value} = org.apache.spark.unsafe.types.UTF8String.fromString($sketch.toString());
+ """
+ ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = childEval.isNull)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => s"($c)")
+ }
+}
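
The new expression also gets a Scala wrapper in functions.scala below; a minimal sketch of calling it through that wrapper (sketchDf and its sketch column are illustrative names for a DataFrame that already holds serialized sketches):

    // Render the stored sketch's human-readable summary.
    import org.apache.spark.sql.datasketches.theta.functions.theta_sketch_to_string
    sketchDf.select(theta_sketch_to_string("sketch").as("summary")).show(truncate = false)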
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
index 240aeea..7445d50 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/theta/functions.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
import org.apache.spark.sql.datasketches.theta.aggregate.{ThetaSketchAggBuild, ThetaSketchAggUnion}
-import org.apache.spark.sql.datasketches.theta.expressions.ThetaSketchGetEstimate
+import org.apache.spark.sql.datasketches.theta.expressions.{ThetaSketchGetEstimate, ThetaSketchToString}
import org.apache.spark.sql.datasketches.common.DatasketchesScalaFunctionBase
object functions extends DatasketchesScalaFunctionBase {
@@ -58,20 +58,28 @@ object functions extends DatasketchesScalaFunctionBase {
theta_sketch_agg_build(Column(columnName))
}
+ def theta_sketch_agg_union(column: Column, lgk: Int, seed: Long): Column = withAggregateFunction {
+ new ThetaSketchAggUnion(column.expr, lgk, seed)
+ }
+
+ def theta_sketch_agg_union(columnName: String, lgk: Int, seed: Long): Column = {
+ theta_sketch_agg_union(Column(columnName), lgk, seed)
+ }
+
def theta_sketch_agg_union(column: Column, lgk: Int): Column = withAggregateFunction {
- new ThetaSketchAggUnion(column.expr, lit(lgk).expr)
+ new ThetaSketchAggUnion(column.expr, lgk)
}
- def theta_sketch_agg_union(columnName: String, lgk: Int): Column = withAggregateFunction {
- new ThetaSketchAggUnion(Column(columnName).expr, lit(lgk).expr)
+ def theta_sketch_agg_union(columnName: String, lgk: Int): Column = {
+ theta_sketch_agg_union(Column(columnName), lgk)
}
def theta_sketch_agg_union(column: Column): Column = withAggregateFunction {
new ThetaSketchAggUnion(column.expr)
}
- def theta_sketch_agg_union(columnName: String): Column = withAggregateFunction {
- new ThetaSketchAggUnion(Column(columnName).expr)
+ def theta_sketch_agg_union(columnName: String): Column = {
+ theta_sketch_agg_union(Column(columnName))
}
def theta_sketch_get_estimate(column: Column): Column = withExpr {
@@ -81,4 +89,12 @@ object functions extends DatasketchesScalaFunctionBase {
def theta_sketch_get_estimate(columnName: String): Column = {
theta_sketch_get_estimate(Column(columnName))
}
+
+ def theta_sketch_to_string(column: Column): Column = withExpr {
+ new ThetaSketchToString(column.expr)
+ }
+
+ def theta_sketch_to_string(columnName: String): Column = {
+ theta_sketch_to_string(Column(columnName))
+ }
}
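
Put together, the wrappers compose end to end under the default parameters; a minimal sketch (spark, data, group and value are assumed names):

    import org.apache.spark.sql.datasketches.theta.functions._
    // Build one sketch per group, union them, then inspect the result.
    val grouped = data.groupBy("group").agg(theta_sketch_agg_build("value").as("sketch"))
    val merged = grouped.agg(theta_sketch_agg_union("sketch").as("merged"))
    merged.select(
      theta_sketch_get_estimate("merged").as("estimate"),
      theta_sketch_to_string("merged").as("summary")
    ).show(truncate = false)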
diff --git a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
index 62a547f..09add01 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/theta/ThetaTest.scala
@@ -99,16 +99,16 @@ class ThetaTest extends SparkSessionManager {
assert(df.head().getAs[Double]("estimate") == 100.0)
}
- test("Theta Sketch build via SQL with lgk and seed") {
+ test("Theta Sketch build from strings via SQL with lgk and seed") {
ThetaFunctionRegistry.registerFunctions(spark)
val n = 100
- val data = (for (i <- 1 to n) yield i).toDF("value")
+ val data = (for (i <- 1 to n) yield i.toString()).toDF("str")
data.createOrReplaceTempView("theta_input_table")
val df = spark.sql(s"""
SELECT
- theta_sketch_get_estimate(theta_sketch_agg_build(value, 14, 111L)) AS estimate
+ theta_sketch_get_estimate(theta_sketch_agg_build(str, 14, 111L)) AS estimate
FROM
theta_input_table
""")
@@ -131,7 +131,7 @@ class ThetaTest extends SparkSessionManager {
df.head().getAs[Double]("estimate") shouldBe (100.0 +- 2.0)
}
- test("Theta Union via Scala") {
+ test("Theta Union via Scala with defauls") {
val numGroups = 10
val numDistinct = 2000
val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
@@ -142,8 +142,34 @@ class ThetaTest extends SparkSessionManager {
assert(result.getAs[Double]("estimate") == numDistinct)
}
+ test("Theta Union via Scala with lgk") {
+ val numGroups = 10
+ val numDistinct = 2000
+ val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+
+ val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value", 14).as("sketch"))
+ val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch", 14).as("merged"))
+ val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
+ assert(result.getAs[Double]("estimate") == numDistinct)
+ }
+
+ test("Theta Union via Scala with lgk and seed") {
+ val numGroups = 10
+ val numDistinct = 2000
+ val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+
+ val groupedDf = data.groupBy("group").agg(theta_sketch_agg_build("value", 14, 111).as("sketch"))
+ val mergedDf = groupedDf.agg(theta_sketch_agg_union("sketch", 14, 111).as("merged"))
+ val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head()
+ assert(result.getAs[Double]("estimate") == numDistinct)
+
+ val toStr: Row = mergedDf.select(theta_sketch_to_string("merged").as("summary")).head()
+ toStr.getAs[String]("summary") should startWith ("\n### HeapCompactSketch")
+ }
+
+ test("Theta Union via SQL with defaults") {
+ ThetaFunctionRegistry.registerFunctions(spark)
- test("Theta Union via SQL default lgk") {
val numGroups = 10
val numDistinct = 2000
val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
@@ -170,6 +196,8 @@ class ThetaTest extends SparkSessionManager {
}
test("Theta Union via SQL with lgk") {
+ ThetaFunctionRegistry.registerFunctions(spark)
+
val numGroups = 10
val numDistinct = 2000
val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
@@ -195,4 +223,31 @@ class ThetaTest extends SparkSessionManager {
assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
}
+ test("Theta Union via SQL with lgk and seed") {
+ ThetaFunctionRegistry.registerFunctions(spark)
+
+ val numGroups = 10
+ val numDistinct = 2000
+ val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+ data.createOrReplaceTempView("theta_input_table")
+
+ val groupedDf = spark.sql(s"""
+ SELECT
+ group,
+ theta_sketch_agg_build(value, 14, 111L) AS sketch
+ FROM
+ theta_input_table
+ GROUP BY
+ group
+ """)
+ groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+ val mergedDf = spark.sql(s"""
+ SELECT
+ theta_sketch_get_estimate(theta_sketch_agg_union(sketch, 14, 111L)) AS estimate
+ FROM
+ theta_sketch_table
+ """)
+ assert(mergedDf.head().getAs[Double]("estimate") == numDistinct)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]