This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch import_kll
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit 7e001c1447cff8a9f7f8a90234fef94c6b827cb8
Author: Jon <[email protected]>
AuthorDate: Wed Dec 25 13:52:39 2024 -0800

    Add udf to import sketch images from outside spark into a dataframe
---
 .../spark/sql/types/KllDoublesSketchType.scala     | 15 ++++++++++--
 src/test/scala/org/apache/spark/sql/KllTest.scala  | 28 +++++++++++++++++++++-
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/src/main/scala/org/apache/spark/sql/types/KllDoublesSketchType.scala b/src/main/scala/org/apache/spark/sql/types/KllDoublesSketchType.scala
index 65a3635..afb56cc 100644
--- a/src/main/scala/org/apache/spark/sql/types/KllDoublesSketchType.scala
+++ b/src/main/scala/org/apache/spark/sql/types/KllDoublesSketchType.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.types
 
-class KllDoublesSketchType extends UserDefinedType[KllDoublesSketchWrapper] {
+import org.apache.spark.sql.functions.udf
+
+class KllDoublesSketchType extends UserDefinedType[KllDoublesSketchWrapper] with Serializable {
   override def sqlType: DataType = DataTypes.BinaryType
 
   override def serialize(wrapper: KllDoublesSketchWrapper): Array[Byte] = {
@@ -34,4 +36,13 @@ class KllDoublesSketchType extends UserDefinedType[KllDoublesSketchWrapper] {
   override def catalogString: String = "KllDoublesSketch"
 }
 
-case object KllDoublesSketchType extends KllDoublesSketchType
+case object KllDoublesSketchType extends KllDoublesSketchType {
+  // udf to allow importing serialized sketches into dataframes
+  val wrapBytes = udf((bytes: Array[Byte]) => {
+    if (bytes == null) {
+      null
+    } else {
+      deserialize(bytes)
+    }
+  })
+}
diff --git a/src/test/scala/org/apache/spark/sql/KllTest.scala b/src/test/scala/org/apache/spark/sql/KllTest.scala
index 49e0a18..15e83c9 100644
--- a/src/test/scala/org/apache/spark/sql/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/KllTest.scala
@@ -19,9 +19,13 @@ package org.apache.spark.sql
 
 import scala.util.Random
 import org.apache.spark.sql.functions._
+import scala.collection.mutable.WrappedArray
+import org.apache.spark.sql.types.{StructType, StructField, IntegerType, BinaryType}
+
 import org.apache.spark.sql.functions_ds._
+import org.apache.datasketches.kll.KllDoublesSketch
+import org.apache.spark.sql.types.KllDoublesSketchType
 import org.apache.spark.registrar.DatasketchesFunctionRegistry
-import scala.collection.mutable.WrappedArray
 
 class KllTest extends SparkSessionManager {
   import spark.implicits._
@@ -34,6 +38,28 @@ class KllTest extends SparkSessionManager {
     (ref zip tstArr).foreach { case (v1, v2) => if (v1 != v2) throw new AssertionError("Values do not match: " + v1 + " != " + v2) }
   }
 
+  test("Load KllDoublesSketch images into dataframe") {
+    val numClass = 10
+    val numSamples = 10000
+
+    // produce a Seq(Array(id, sk))
+    val data = for (i <- 1 to numClass) yield {
+      val sk = KllDoublesSketch.newHeapInstance(200)
+      for (j <- 0 until numSamples) sk.update(Random.nextDouble)
+      Row(i, sk.toByteArray)
+    }
+
+    val schema = StructType(Array(
+      StructField("id", IntegerType, false),
+      StructField("kll", BinaryType, true)
+    ))
+
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      .select($"id", KllDoublesSketchType.wrapBytes($"kll").as("sketch"))
+
+    df.show()
+  }
+
   test("KLL Doubles Sketch via scala") {
     val n = 100
     val data = (for (i <- 1 to n) yield i.toDouble).toDF("value")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to