This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch pyspark_java17 in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit 2d0096da117a3b735efd839bc08db6b818c2a3be Author: Jon Malkin <[email protected]> AuthorDate: Wed Feb 19 20:39:46 2025 -0800 Fix get_pmf/cdf codegen not using proper naming, and add support for java17 in pytest (we hope) --- python/tests/conftest.py | 61 +++++++++++++++++++--- .../expressions/KllDoublesSketchExpressions.scala | 14 +++-- .../sql/datasketches/SparkSessionManager.scala | 2 + 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 5a3fe2d..0de51e9 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -15,20 +15,67 @@ # specific language governing permissions and limitations # under the License. +import os +import subprocess import pytest +from typing import Optional from pyspark.sql import SparkSession from datasketches_spark import get_dependency_classpath [email protected](scope="session") +# Attempt to determine the Java version -- this may still fail +# based on command line arguments or other pyspark config, but +# it should often work for local testing. 
+# Favor looking for a java executable in $SPARK_HOME, else +# $JAVA_HOME, else $PATH from just running `java` +# May return an incorrect number for Java 8 (aka 1.8) but +# we only care about detecting 17 and higher +def get_java_version() -> Optional[int]: + java_cmd = 'java' # fallback option + # check $JAVA_HOME first + if 'JAVA_HOME' in os.environ: + java_cmd = os.path.join(os.environ['JAVA_HOME'], 'bin', 'java') + + # check $SPARK_HOME next -- top choice, if available + if 'SPARK_HOME' in os.environ: + spark_java = os.path.join(os.environ['SPARK_HOME'], 'bin', 'java') + if os.path.exists(spark_java): + java_cmd = spark_java + + # now attempt to run java -version + try: + # java -version prints to stderr + # Version is the 3rd argument of the 1st line and may be in double quotes + # We care only about the major version + output = subprocess.run([java_cmd, '-version'], + capture_output=True, + text=True) + version = output.stderr.splitlines()[0] + return int(version.split()[2].strip('"').split('.')[0]) + + except Exception as e: + print(f"Could not determine java version: {e}") + return None + [email protected](scope='session') def spark(): + java_opts = '' + java_version = get_java_version() + + if java_version is not None and java_version >= 17 and java_version < 21: + java_opts = '--add-modules=jdk.incubator.foreign --add-exports=java.base/sun.nio.ch=ALL-UNNAMED' + os.environ['PYSPARK_SUBMIT_ARGS'] = f"--driver-java-options '{java_opts}' pyspark-shell" + spark = ( SparkSession.builder - .appName("test") - .master("local[*]") - .config("spark.driver.userClassPathFirst", "true") - .config("spark.executor.userClassPathFirst", "true") - .config("spark.driver.extraClassPath", get_dependency_classpath()) - .config("spark.executor.extraClassPath", get_dependency_classpath()) + .appName('test') + .master('local[1]') + .config('spark.driver.userClassPathFirst', 'true') + .config('spark.executor.userClassPathFirst', 'true') + .config('spark.driver.extraClassPath', 
get_dependency_classpath()) + .config('spark.executor.extraClassPath', get_dependency_classpath()) + .config('spark.executor.extraJavaOptions', java_opts) + .config('spark.driver.bindAddress', 'localhost') + .config('spark.driver.host', 'localhost') .getOrCreate() ) yield spark diff --git a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala index 8cd9d65..087a9cb 100644 --- a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala +++ b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala @@ -285,18 +285,22 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression, override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): ExprCode = { val sketchEval = sketchExpr.genCode(ctx) - val sketch = ctx.freshName("sketch") val splitPointsEval = splitPointsExpr.genCode(ctx) + val sketch = ctx.freshName("sketch") + val searchCriterion = ctx.freshName("searchCriterion") + val splitPoints = ctx.freshName("splitPoints") + val result = ctx.freshName("result") + val code = s""" |${sketchEval.code} |${splitPointsEval.code} - |org.apache.datasketches.quantilescommon.QuantileSearchCriteria searchCriteria = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"}; + |org.apache.datasketches.quantilescommon.QuantileSearchCriteria $searchCriterion = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"}; |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value}); - 
|final double[] splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray(); - |final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"}; + |final double[] $splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray(); + |final double[] $result = ${if (isPmf) s"$sketch.getPMF($splitPoints, $searchCriterion)" else s"$sketch.getCDF($splitPoints, $searchCriterion)"}; |final boolean ${ev.isNull} = false; - |org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData(result); + |org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData($result); """.stripMargin ev.copy(code = CodeBlock(Seq(code), Seq.empty)) } diff --git a/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala b/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala index 8aa20bf..1430de2 100644 --- a/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala +++ b/src/test/scala/org/apache/spark/sql/datasketches/SparkSessionManager.scala @@ -34,6 +34,8 @@ trait SparkSessionManager extends AnyFunSuite with BeforeAndAfterAll { .builder() .appName("datasketches-spark-tests") .master("local[3]") + .config("spark.driver.bindAddress", "localhost") + .config("spark.driver.host", "localhost") //.config("spark.sql.debug.codegen", "true") .getOrCreate() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
