This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch pyspark_java17
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit 2d0096da117a3b735efd839bc08db6b818c2a3be
Author: Jon Malkin <[email protected]>
AuthorDate: Wed Feb 19 20:39:46 2025 -0800

    Fix get_pmf/cdf codegen not using proper naming, and add support for java17 in pytest (we hope)
---
 python/tests/conftest.py                           | 61 +++++++++++++++++++---
 .../expressions/KllDoublesSketchExpressions.scala  | 14 +++--
 .../sql/datasketches/SparkSessionManager.scala     |  2 +
 3 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 5a3fe2d..0de51e9 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -15,20 +15,67 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
+import subprocess
 import pytest
+from typing import Optional
 from pyspark.sql import SparkSession
 from datasketches_spark import get_dependency_classpath
 
-@pytest.fixture(scope="session")
+# Attempt to determine the Java version -- this may still fail
+# based on command-line arguments or other pyspark config, but
+# it should often work for local testing.
+# Favor a java executable in $SPARK_HOME, else $JAVA_HOME, else
+# whatever `java` resolves to on $PATH.
+# May return an incorrect number for Java 8 (aka 1.8), but we
+# only care about detecting 17 and higher.
+def get_java_version() -> Optional[int]:
+    java_cmd = 'java' # fallback option
+    # check $JAVA_HOME first
+    if 'JAVA_HOME' in os.environ:
+        java_cmd = os.path.join(os.environ['JAVA_HOME'], 'bin', 'java')
+
+    # check $SPARK_HOME next -- top choice, if available
+    if 'SPARK_HOME' in os.environ:
+        spark_java = os.path.join(os.environ['SPARK_HOME'], 'bin', 'java')
+        if os.path.exists(spark_java):
+            java_cmd = spark_java
+
+    # now attempt to run java -version
+    try:
+        # java -version prints to stderr
+        # Version is the 3rd token of the 1st line and may be in double quotes
+        # We care only about the major version
+        output = subprocess.run([java_cmd, '-version'],
+                                capture_output=True,
+                                text=True)
+        version = output.stderr.splitlines()[0]
+        return int(version.split()[2].strip('"').split('.')[0])
+
+    except Exception as e:
+        print(f"Could not determine java version: {e}")
+        return None
+
+@pytest.fixture(scope='session')
 def spark():
+    java_opts = ''
+    java_version = get_java_version()
+
+    if java_version is not None and 17 <= java_version < 21:
+        java_opts = '--add-modules=jdk.incubator.foreign --add-exports=java.base/sun.nio.ch=ALL-UNNAMED'
+        os.environ['PYSPARK_SUBMIT_ARGS'] = f"--driver-java-options '{java_opts}' pyspark-shell"
+
     spark = (
         SparkSession.builder
-        .appName("test")
-        .master("local[*]")
-        .config("spark.driver.userClassPathFirst", "true")
-        .config("spark.executor.userClassPathFirst", "true")
-        .config("spark.driver.extraClassPath", get_dependency_classpath())
-        .config("spark.executor.extraClassPath", get_dependency_classpath())
+        .appName('test')
+        .master('local[1]')
+        .config('spark.driver.userClassPathFirst', 'true')
+        .config('spark.executor.userClassPathFirst', 'true')
+        .config('spark.driver.extraClassPath', get_dependency_classpath())
+        .config('spark.executor.extraClassPath', get_dependency_classpath())
+        .config('spark.executor.extraJavaOptions', java_opts)
+        .config('spark.driver.bindAddress', 'localhost')
+        .config('spark.driver.host', 'localhost')
         .getOrCreate()
     )
     yield spark
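
A note on the parsing above: the logic in get_java_version() is easy to
sanity-check in isolation. Below is a minimal standalone sketch; the sample
banner strings are illustrative assumptions, not captured JVM output.

    # Standalone sketch of the version parsing in get_java_version().
    # The sample banners are assumptions for illustration only.
    def parse_major(first_line: str) -> int:
        # Version is the 3rd token and may be wrapped in double quotes;
        # only the major component matters here.
        return int(first_line.split()[2].strip('"').split('.')[0])

    assert parse_major('openjdk version "17.0.9" 2023-10-17') == 17
    assert parse_major('java version "21" 2023-09-19 LTS') == 21
    # Java 8 banners read "1.8.0_x", so the major parses as 1 -- harmless,
    # since the fixture only needs to detect 17 and higher.
    assert parse_major('java version "1.8.0_392"') == 1
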
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
index 8cd9d65..087a9cb 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
@@ -285,18 +285,22 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,
 
   override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): ExprCode = {
     val sketchEval = sketchExpr.genCode(ctx)
-    val sketch = ctx.freshName("sketch")
     val splitPointsEval = splitPointsExpr.genCode(ctx)
+    val sketch = ctx.freshName("sketch")
+    val searchCriterion = ctx.freshName("searchCriterion")
+    val splitPoints = ctx.freshName("splitPoints")
+    val result = ctx.freshName("result")
+
     val code =
       s"""
          |${sketchEval.code}
          |${splitPointsEval.code}
-         |org.apache.datasketches.quantilescommon.QuantileSearchCriteria searchCriteria = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"};
+         |org.apache.datasketches.quantilescommon.QuantileSearchCriteria $searchCriterion = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"};
          |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value});
-         |final double[] splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray();
-         |final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"};
+         |final double[] $splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray();
+         |final double[] $result = ${if (isPmf) s"$sketch.getPMF($splitPoints, $searchCriterion)" else s"$sketch.getCDF($splitPoints, $searchCriterion)"};
          |final boolean ${ev.isNull} = false;
-         |org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData(result);
+         |org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData($result);
        """.stripMargin
     ev.copy(code = CodeBlock(Seq(code), Seq.empty))
   }
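
Why the rename matters: Spark's whole-stage codegen can splice the generated
fragments from many expressions into a single Java method, so hard-coded
locals such as splitPoints or result collide as soon as two get_pmf/get_cdf
calls land in the same generated function. A toy Python sketch of the
freshName idea follows -- a model of the behavior, not Spark's implementation.

    # Toy model of CodegenContext.freshName: a unique suffix per request
    # keeps locals from separately generated fragments from colliding.
    from itertools import count

    _ids = count()

    def fresh_name(base: str) -> str:
        return f"{base}_{next(_ids)}"

    def gen_pmf_fragment() -> str:
        split_points = fresh_name("splitPoints")
        result = fresh_name("result")
        return (f"double[] {split_points} = toDoubleArray(input);\n"
                f"double[] {result} = sketch.getPMF({split_points}, criterion);")

    # Two fragments fused into one method now declare splitPoints_0/result_1
    # and splitPoints_2/result_3 instead of redeclaring the same names and
    # failing to compile.
    print(gen_pmf_fragment())
    print(gen_pmf_fragment())
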
diff --git a/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala b/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala
index 8aa20bf..1430de2 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala
@@ -34,6 +34,8 @@ trait SparkSessionManager extends AnyFunSuite with BeforeAndAfterAll {
       .builder()
       .appName("datasketches-spark-tests")
       .master("local[3]")
+      .config("spark.driver.bindAddress", "localhost")
+      .config("spark.driver.host", "localhost")
       //.config("spark.sql.debug.codegen", "true")
       .getOrCreate()
 


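One more note on the Java 17 plumbing in conftest.py above: pyspark's
launcher reads PYSPARK_SUBMIT_ARGS when it starts the driver JVM, and the
value must end with the token pyspark-shell. A short sketch of the exported
string (values copied from the patch):

    # What the fixture exports on a Java 17-20 JVM.
    java_opts = ('--add-modules=jdk.incubator.foreign '
                 '--add-exports=java.base/sun.nio.ch=ALL-UNNAMED')
    submit_args = f"--driver-java-options '{java_opts}' pyspark-shell"
    # Executors get the same flags via spark.executor.extraJavaOptions;
    # the trailing 'pyspark-shell' token is required by the launcher.
    print(submit_args)

Pinning spark.driver.bindAddress and spark.driver.host to localhost in both
test harnesses helps avoid startup failures on machines whose hostname does
not resolve.
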
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
