This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch theta
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit ed9ef30461aada12b44179f7ffce05cd02b92432
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Mon Jan 6 16:22:39 2025 -0800

    theta prototype
---
 .../spark/sql/aggregate/ThetaSketchBuild.scala     | 137 +++++++++++++++++++++
 .../apache/spark/sql/aggregate/ThetaUnion.scala    | 133 ++++++++++++++++++++
 .../spark/sql/expressions/ThetaExpressions.scala   |  74 +++++++++++
 .../scala/org/apache/spark/sql/functions_ds.scala  |  52 +++++++-
 .../registrar/DatasketchesFunctionRegistry.scala   |  12 +-
 .../apache/spark/sql/types/ThetaSketchType.scala   |  37 ++++++
 .../spark/sql/types/ThetaSketchWrapper.scala       |  51 ++++++++
 .../scala/org/apache/spark/sql/ThetaTest.scala     | 132 ++++++++++++++++++++
 8 files changed, 619 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala b/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala
new file mode 100644
index 0000000..d1bb88d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/aggregate/ThetaSketchBuild.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.aggregate
+
+import org.apache.datasketches.theta.{UpdateSketch, SetOperation}
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType, NumericType, FloatType, DoubleType, ThetaSketchWrapper, ThetaSketchType}
+
+/**
+ * The ThetaSketchBuild function creates a Theta sketch from a column of values,
+ * which can be used to estimate the distinct count.
+ *
+ * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
+ *
+ * @param child child expression, from which to build a sketch
+ * @param lgk the size-accuracy trade-off parameter for the sketch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, lgk) - Creates a Theta Sketch and returns the binary representation.
+      `lgk` (optional, default: 12) the size-accuracy trade-off parameter.""",
+  examples = """
+    Example:
+      > SELECT theta_sketch_get_estimate(_FUNC_(col, 12)) FROM VALUES (1), (2), (3), (4), (5) tab(col);
+       5.0
+  """,
+)
+// scalastyle:on line.size.limit
+case class ThetaSketchBuild(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[ThetaSketchWrapper]
+    with BinaryLike[Expression]
+    with ExpectsInputTypes {
+
+  lazy val lgk: Int = {
+    right.eval() match {
+      case null => 12
+      case lgk: Int => lgk
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
+  def this(child: Expression) = this(child, Literal(12), 0, 0)
+  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
+  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchBuild =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchBuild =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaSketchBuild = {
+    copy(left = newLeft, right = newRight)
+  }
+
+  override def prettyName: String = "theta_sketch_build"
+
+  override def dataType: DataType = ThetaSketchType
+
+  override def nullable: Boolean = false
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType, LongType, FloatType, DoubleType)
+
+  override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(updateSketch = Some(UpdateSketch.builder().setLogNominalEntries(lgk).build()))
+
+  override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
+    val value = left.eval(input)
+    if (value != null) {
+      left.dataType match {
+        case DoubleType => wrapper.updateSketch.get.update(value.asInstanceOf[Double])
+        case FloatType => wrapper.updateSketch.get.update(value.asInstanceOf[Float])
+        case IntegerType => wrapper.updateSketch.get.update(value.asInstanceOf[Int])
+        case LongType => wrapper.updateSketch.get.update(value.asInstanceOf[Long])
+        case _ => throw new SparkUnsupportedOperationException(
+          s"Unsupported input type ${left.dataType.catalogString}",
+          Map("dataType" -> dataType.toString))
+      }
+    }
+    wrapper
+  }
+
+  override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
+    if (other != null && !other.compactSketch.get.isEmpty) {
+      if (wrapper.union.isEmpty) {
+        wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        if (wrapper.compactSketch.isDefined) {
+          wrapper.union.get.union(wrapper.compactSketch.get)
+          wrapper.compactSketch = None
+        }
+      }
+      wrapper.union.get.union(other.compactSketch.get)
+    }
+    wrapper
+  }
+
+  override def eval(wrapper: ThetaSketchWrapper): Any = {
+    if (wrapper == null || wrapper.union.isEmpty) {
+      null
+    } else {
+      wrapper.union.get.getResult.toByteArrayCompressed
+    }
+  }
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    ThetaSketchType.serialize(wrapper)
+  }
+
+  override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    ThetaSketchType.deserialize(bytes)
+  }
+}
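
For context on `lgk` above: it is the log-base-2 of the sketch's nominal entry count, so the sketch retains at most about 2^lgk hash values, and larger values trade memory for accuracy. A minimal standalone sketch of that trade-off, using only the datasketches-java calls already seen in this file (the object name and test values are illustrative, not part of the commit):

    import org.apache.datasketches.theta.UpdateSketch

    object LgkTradeOffDemo {
      def main(args: Array[String]): Unit = {
        // Same builder API the aggregate wraps; lgk = log2(nominal entries).
        val coarse = UpdateSketch.builder().setLogNominalEntries(4).build()  // ~16 entries
        val fine = UpdateSketch.builder().setLogNominalEntries(12).build()   // ~4096 entries, the default above
        (1 to 100000).foreach { i =>
          coarse.update(i.toLong)
          fine.update(i.toLong)
        }
        // Both estimates center on 100000, but the lgk=4 sketch
        // has a far higher relative error.
        println(s"lgk=4:  ${coarse.getEstimate}")
        println(s"lgk=12: ${fine.getEstimate}")
      }
    }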
diff --git a/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala b/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala
new file mode 100644
index 0000000..29dd408
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/aggregate/ThetaUnion.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.aggregate
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.{Sketch, SetOperation}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.types.{AbstractDataType, DataType, ThetaSketchWrapper, ThetaSketchType}
+import org.apache.spark.SparkUnsupportedOperationException
+
+/**
+ * Theta Union operation.
+ *
+ * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
+ *
+ * @param child child expression, on which to perform the union operation
+ * @param lgk the size-accuracy trade-off parameter for the sketch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, lgk) - Performs a Theta Union operation and returns the result as a Theta Sketch in binary form
+  """,
+  examples = """
+    Example:
+      > SELECT theta_sketch_get_estimate(_FUNC_(sketch)) FROM (SELECT theta_sketch_build(col) as sketch FROM VALUES (1), (2), (3) tab(col) UNION ALL SELECT theta_sketch_build(col) as sketch FROM VALUES (3), (4), (5) tab(col));
+       5.0
+  """
+)
+// scalastyle:on line.size.limit
+case class ThetaUnion(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[ThetaSketchWrapper]
+    with BinaryLike[Expression]
+    with ExpectsInputTypes {
+
+  lazy val lgk: Int = {
+    right.eval() match {
+      case null => 12
+      case lgk: Int => lgk
+      case _ => throw new SparkUnsupportedOperationException(
+        s"Unsupported input type ${right.dataType.catalogString}",
+        Map("dataType" -> dataType.toString))
+    }
+  }
+
+  def this(left: Expression, right: Expression) = this(left, right, 0, 0)
+  def this(left: Expression) = this(left, Literal(12), 0, 0)
+//  def this(child: Expression, lgk: Expression) = this(child, lgk, 0, 0)
+//  def this(child: Expression, lgk: Int) = this(child, Literal(lgk), 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaUnion =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaUnion =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ThetaUnion = {
+    copy(left = newLeft, right = newRight)
+  }
+
+  // overrides for TypedImperativeAggregate
+  override def prettyName: String = "theta_union"
+  override def dataType: DataType = ThetaSketchType
+  override def nullable: Boolean = false
+
+  // TODO: refine this?
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType)
+
+  override def createAggregationBuffer(): ThetaSketchWrapper = new ThetaSketchWrapper(union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion))
+
+  override def update(wrapper: ThetaSketchWrapper, input: InternalRow): ThetaSketchWrapper = {
+    val bytes = left.eval(input)
+    if (bytes != null) {
+      left.dataType match {
+        case ThetaSketchType =>
+          wrapper.union.get.union(Sketch.wrap(Memory.wrap(bytes.asInstanceOf[Array[Byte]])))
+        case _ => throw new SparkUnsupportedOperationException(
+          s"Unsupported input type ${left.dataType.catalogString}",
+          Map("dataType" -> dataType.toString))
+      }
+    }
+    wrapper
+  }
+
+  override def merge(wrapper: ThetaSketchWrapper, other: ThetaSketchWrapper): ThetaSketchWrapper = {
+    if (other != null && !other.compactSketch.get.isEmpty) {
+      if (wrapper.union.isEmpty) {
+        wrapper.union = Some(SetOperation.builder().setLogNominalEntries(lgk).buildUnion)
+        if (wrapper.compactSketch.isDefined) {
+          wrapper.union.get.union(wrapper.compactSketch.get)
+          wrapper.compactSketch = None
+        }
+      }
+      wrapper.union.get.union(other.compactSketch.get)
+    }
+    wrapper
+  }
+
+  override def eval(wrapper: ThetaSketchWrapper): Any = {
+    wrapper.union.get.getResult.toByteArrayCompressed
+  }
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    ThetaSketchType.serialize(wrapper)
+  }
+
+  override def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    ThetaSketchType.deserialize(bytes)
+  }
+}
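
The example in the usage string above, where (1, 2, 3) unioned with (3, 4, 5) yields 5.0, can be reproduced directly against the underlying datasketches-java set operation. A minimal sketch, assuming only the library calls this file already uses (the object name is illustrative):

    import org.apache.datasketches.theta.{SetOperation, UpdateSketch}

    object ThetaUnionDemo {
      def main(args: Array[String]): Unit = {
        val a = UpdateSketch.builder().build() // default lgk = 12
        Seq(1L, 2L, 3L).foreach(v => a.update(v))
        val b = UpdateSketch.builder().build()
        Seq(3L, 4L, 5L).foreach(v => b.update(v))

        // The union keeps each distinct hash once, so the overlapping
        // value 3 is not double-counted.
        val union = SetOperation.builder().buildUnion
        union.union(a.compact())
        union.union(b.compact())
        println(union.getResult.getEstimate) // 5.0
      }
    }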
diff --git a/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala b/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala
new file mode 100644
index 0000000..37709a7
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/expressions/ThetaExpressions.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.Sketch
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpectsInputTypes, UnaryExpression, BinaryExpression}
+import org.apache.spark.sql.catalyst.expressions.NullIntolerant
+import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
+import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types.{AbstractDataType, DataType, ArrayType, DoubleType, ThetaSketchType}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the distinct count estimate from a given sketch
+  """,
+  examples = """
+    Example:
+      > SELECT _FUNC_(theta_sketch_build(col)) FROM VALUES (1), (2), (3) tab(col);
+       3.0
+  """
+)
+case class ThetaSketchGetEstimate(child: Expression)
+  extends UnaryExpression
+    with ExpectsInputTypes
+    with NullIntolerant {
+
+  override protected def withNewChildInternal(newChild: Expression): ThetaSketchGetEstimate = {
+    copy(child = newChild)
+  }
+
+  override def prettyName: String = "theta_sketch_get_estimate"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ThetaSketchType)
+
+  override def dataType: DataType = DoubleType
+
+  override def nullSafeEval(input: Any): Any = {
+    Sketch.wrap(Memory.wrap(input.asInstanceOf[Array[Byte]])).getEstimate
+  }
+
+  override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = {
+    val childEval = child.genCode(ctx)
+    val sketch = ctx.freshName("sketch")
+    val code = s"""
+      ${childEval.code}
+      final org.apache.datasketches.theta.Sketch $sketch = org.apache.spark.sql.types.ThetaSketchWrapper.wrapAsReadOnlySketch(${childEval.value});
+      final double ${ev.value} = $sketch.getEstimate();
+    """
+    ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = childEval.isNull)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"($c)")
+  }
+}
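
ThetaSketchGetEstimate receives its input as the BinaryType bytes the aggregates emit and reads them in place via Sketch.wrap rather than copying onto the heap. The round trip can be sketched outside Spark as follows; the object name is illustrative, and the calls match the ones used above:

    import org.apache.datasketches.memory.Memory
    import org.apache.datasketches.theta.{Sketch, UpdateSketch}

    object GetEstimateDemo {
      def main(args: Array[String]): Unit = {
        val updateSketch = UpdateSketch.builder().build()
        (1 to 1000).foreach(i => updateSketch.update(i.toLong))

        // The aggregates emit compressed compact bytes; wrap() reads them
        // without a heap copy, which is what nullSafeEval does above.
        val bytes: Array[Byte] = updateSketch.compact().toByteArrayCompressed
        println(Sketch.wrap(Memory.wrap(bytes)).getEstimate) // 1000.0, exact below the nominal size
      }
    }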
diff --git a/src/main/scala/org/apache/spark/sql/functions_ds.scala b/src/main/scala/org/apache/spark/sql/functions_ds.scala
index 7c2a96e..f30c452 100644
--- a/src/main/scala/org/apache/spark/sql/functions_ds.scala
+++ b/src/main/scala/org/apache/spark/sql/functions_ds.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
+import org.apache.spark.sql.aggregate.{ThetaSketchBuild, ThetaUnion}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.expressions._
-import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.types.ArrayType
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types.{ArrayType, DoubleType}
 
 // this class defines and maps all the variants of each function invocation, analogous
 // to the functions object in org.apache.spark.sql.functions
@@ -152,4 +151,45 @@ object functions_ds {
     kll_get_cdf(Column(columnName), splitPoints)
   }
 
+  // Theta
+
+  def theta_sketch_build(column: Column, lgk: Int): Column = withAggregateFunction {
+    new ThetaSketchBuild(column.expr, lgk)
+  }
+
+  def theta_sketch_build(columnName: String, lgk: Int): Column = {
+    theta_sketch_build(Column(columnName), lgk)
+  }
+
+  def theta_sketch_build(column: Column): Column = withAggregateFunction {
+    new ThetaSketchBuild(column.expr)
+  }
+
+  def theta_sketch_build(columnName: String): Column = {
+    theta_sketch_build(Column(columnName))
+  }
+
+  def theta_union(column: Column, lgk: Int): Column = withAggregateFunction {
+    new ThetaUnion(column.expr, lit(lgk).expr)
+  }
+
+  def theta_union(columnName: String, lgk: Int): Column = withAggregateFunction {
+    new ThetaUnion(Column(columnName).expr, lit(lgk).expr)
+  }
+
+  def theta_union(column: Column): Column = withAggregateFunction {
+    new ThetaUnion(column.expr)
+  }
+
+  def theta_union(columnName: String): Column = withAggregateFunction {
+    new ThetaUnion(Column(columnName).expr)
+  }
+
+  def theta_sketch_get_estimate(column: Column): Column = withExpr {
+    new ThetaSketchGetEstimate(column.expr)
+  }
+
+  def theta_sketch_get_estimate(columnName: String): Column = {
+    theta_sketch_get_estimate(Column(columnName))
+  }
 }
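
Taken together, the new functions compose in the DataFrame API as below. This is a usage sketch modeled on the test cases later in this commit; the local session setup is an assumption, not part of the commit:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions_ds._

    object ThetaDataFrameDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        import spark.implicits._

        // Build one sketch per group, merge them, then read the estimate.
        val data = (1 to 2000).map(i => (i % 10, i)).toDF("group", "value")
        val estimate = data
          .groupBy("group").agg(theta_sketch_build("value").as("sketch"))
          .agg(theta_union("sketch").as("merged"))
          .select(theta_sketch_get_estimate("merged").as("estimate"))
          .head.getAs[Double]("estimate")
        println(estimate) // 2000.0, exact while below the sketches' nominal size
        spark.stop()
      }
    }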
diff --git a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
index 5ab1738..58eacee 100644
--- a/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
+++ b/src/main/scala/org/apache/spark/sql/registrar/DatasketchesFunctionRegistry.scala
@@ -27,8 +27,10 @@ import scala.reflect.ClassTag
 
 // DataSketches imports
 import org.apache.spark.sql.aggregate.{KllDoublesSketchAgg, KllDoublesMergeAgg}
-import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax}
-import org.apache.spark.sql.expressions.KllGetPmfCdf
+import org.apache.spark.sql.expressions.{KllGetMin, KllGetMax, KllGetPmfCdf}
+
+import org.apache.spark.sql.aggregate.{ThetaSketchBuild, ThetaUnion}
+import org.apache.spark.sql.expressions.ThetaSketchGetEstimate
 
 // based on org.apache.spark.sql.catalyst.FunctionRegistry
 trait DatasketchesFunctionRegistry {
@@ -80,6 +82,10 @@ object DatasketchesFunctionRegistry extends DatasketchesFunctionRegistry {
     complexExpression[KllGetPmfCdf]("kll_get_cdf") { args: Seq[Expression] =>
       val isInclusive = if (args.length > 2) args(2).eval().asInstanceOf[Boolean] else true
      new KllGetPmfCdf(args(0), args(1), isInclusive = isInclusive, isPmf = false)
-    }
+    },
+
+    expression[ThetaSketchBuild]("theta_sketch_build"),
+    expression[ThetaUnion]("theta_union"),
+    expression[ThetaSketchGetEstimate]("theta_sketch_get_estimate"),
   )
 }
diff --git a/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala b/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala
new file mode 100644
index 0000000..e5a5e2c
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/types/ThetaSketchType.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+class ThetaSketchType extends UserDefinedType[ThetaSketchWrapper] {
+  override def sqlType: DataType = DataTypes.BinaryType
+
+  override def serialize(wrapper: ThetaSketchWrapper): Array[Byte] = {
+    wrapper.serialize
+  }
+
+  override def deserialize(data: Any): ThetaSketchWrapper = {
+    val bytes = data.asInstanceOf[Array[Byte]]
+    ThetaSketchWrapper.deserialize(bytes)
+  }
+
+  override def userClass: Class[ThetaSketchWrapper] = classOf[ThetaSketchWrapper]
+
+  override def catalogString: String = "ThetaSketch"
+}
+
+case object ThetaSketchType extends ThetaSketchType
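
Once registered, the same three functions are available from SQL. A sketch of that flow for a spark-shell session (where `spark` is predefined), mirroring the tests below; the view name is illustrative and the registry import is the one the test file uses:

    import org.apache.spark.registrar.DatasketchesFunctionRegistry
    import spark.implicits._

    // Register theta_sketch_build, theta_union, theta_sketch_get_estimate.
    DatasketchesFunctionRegistry.registerFunctions(spark)
    (1 to 100).toDF("value").createOrReplaceTempView("theta_input_table")
    spark.sql("""
      SELECT theta_sketch_get_estimate(theta_sketch_build(value, 14)) AS estimate
      FROM theta_input_table
    """).show() // 100.0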
diff --git a/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala b/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala
new file mode 100644
index 0000000..86e3b89
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/types/ThetaSketchWrapper.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.datasketches.theta.{UpdateSketch, CompactSketch, Union}
+import org.apache.datasketches.memory.Memory
+
+@SQLUserDefinedType(udt = classOf[ThetaSketchType])
+class ThetaSketchWrapper(var updateSketch: Option[UpdateSketch] = None, var compactSketch: Option[CompactSketch] = None, var union: Option[Union] = None) {
+
+  def serialize: Array[Byte] = {
+    if (updateSketch.isDefined) return updateSketch.get.compact().toByteArrayCompressed
+    else if (compactSketch.isDefined) return compactSketch.get.toByteArrayCompressed
+    else if (union.isDefined) return union.get.getResult.toByteArrayCompressed
+    null
+  }
+
+  override def toString(): String = {
+    if (updateSketch.isDefined) return updateSketch.get.toString
+    else if (compactSketch.isDefined) return compactSketch.get.toString
+    else if (union.isDefined) return union.get.toString
+    ""
+  }
+}
+
+object ThetaSketchWrapper {
+  def deserialize(bytes: Array[Byte]): ThetaSketchWrapper = {
+    new ThetaSketchWrapper(compactSketch = Some(CompactSketch.heapify(Memory.wrap(bytes))))
+  }
+
+  // this can go away in favor of directly calling Sketch.wrap
+  // from codegen once janino can generate Java 8+ code
+  def wrapAsReadOnlySketch(bytes: Array[Byte]): CompactSketch = {
+    CompactSketch.wrap(Memory.wrap(bytes))
+  }
+}
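
Note that the wrapper's serialize/deserialize pair is deliberately asymmetric: serialization compacts whichever sketch variant is populated, and deserialization always restores a read-only CompactSketch. A small round-trip sketch under that reading (the object name is illustrative):

    import org.apache.datasketches.theta.UpdateSketch
    import org.apache.spark.sql.types.ThetaSketchWrapper

    object WrapperRoundTripDemo {
      def main(args: Array[String]): Unit = {
        val sketch = UpdateSketch.builder().build()
        (1 to 500).foreach(i => sketch.update(i.toLong))

        // serialize compacts the update sketch into compressed bytes ...
        val wrapper = new ThetaSketchWrapper(updateSketch = Some(sketch))
        val bytes = wrapper.serialize

        // ... and deserialize always comes back as a compactSketch.
        val restored = ThetaSketchWrapper.deserialize(bytes)
        println(restored.updateSketch.isDefined)        // false
        println(restored.compactSketch.get.getEstimate) // 500.0
      }
    }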
diff --git a/src/test/scala/org/apache/spark/sql/ThetaTest.scala b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
new file mode 100644
index 0000000..18fc021
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/ThetaTest.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions_ds._
+import org.apache.spark.registrar.DatasketchesFunctionRegistry
+
+class ThetaTest extends SparkSessionManager {
+  import spark.implicits._
+
+  test("Theta Sketch build via Scala") {
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+
+    val sketchDf = data.agg(theta_sketch_build("value").as("sketch"))
+    val result: Row = sketchDf.select(theta_sketch_get_estimate("sketch").as("estimate")).head
+
+    assert(result.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via SQL default lgk") {
+    DatasketchesFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_build(value)) AS estimate
+      FROM
+        theta_input_table
+    """)
+    assert(df.head.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Sketch build via SQL with lgk") {
+    DatasketchesFunctionRegistry.registerFunctions(spark)
+
+    val n = 100
+    val data = (for (i <- 1 to n) yield i).toDF("value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val df = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_sketch_build(value, 14)) AS estimate
+      FROM
+        theta_input_table
+    """)
+    assert(df.head.getAs[Double]("estimate") == 100.0)
+  }
+
+  test("Theta Union via Scala") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+
+    val groupedDf = data.groupBy("group").agg(theta_sketch_build("value").as("sketch"))
+    val mergedDf = groupedDf.agg(theta_union("sketch").as("merged"))
+    val result: Row = mergedDf.select(theta_sketch_get_estimate("merged").as("estimate")).head
+    assert(result.getAs[Double]("estimate") == numDistinct)
+  }
+
+/*
+  test("Theta Union via SQL default lgk") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val groupedDf = spark.sql(s"""
+      SELECT
+        group,
+        theta_sketch_build(value) AS sketch
+      FROM
+        theta_input_table
+      GROUP BY
+        group
+    """)
+    groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+    val mergedDf = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_union(sketch)) AS estimate
+      FROM
+        theta_sketch_table
+    """)
+    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+  }
+
+  test("Theta Union via SQL with lgk") {
+    val numGroups = 10
+    val numDistinct = 2000
+    val data = (for (i <- 1 to numDistinct) yield (i % numGroups, i)).toDF("group", "value")
+    data.createOrReplaceTempView("theta_input_table")
+
+    val groupedDf = spark.sql(s"""
+      SELECT
+        group,
+        theta_sketch_build(value, 14) AS sketch
+      FROM
+        theta_input_table
+      GROUP BY
+        group
+    """)
+    groupedDf.createOrReplaceTempView("theta_sketch_table")
+
+    val mergedDf = spark.sql(s"""
+      SELECT
+        theta_sketch_get_estimate(theta_union(sketch, 14)) AS estimate
+      FROM
+        theta_sketch_table
+    """)
+    assert(mergedDf.head.getAs[Double]("estimate") == numDistinct)
+  }
+*/
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
