This is an automated email from the ASF dual-hosted git repository.

kimahriman pushed a commit to branch spark-4
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 5dffc7b2d4455537500288953c0787690c19c961
Author: Adam Binford <[email protected]>
AuthorDate: Wed Dec 11 07:06:13 2024 -0500

    Add Spark 4 support
---
 pom.xml                                            |  19 ++++
 .../sql/sedona_sql/expressions/DataFrameAPI.scala  | 107 ++++++++++++++-------
 2 files changed, 91 insertions(+), 35 deletions(-)

diff --git a/pom.xml b/pom.xml
index 6ced20603e..79a3aec4bb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -745,6 +745,25 @@
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
             </properties>
         </profile>
+        <profile>
+            <id>sedona-spark-4.0</id>
+            <activation>
+                <property>
+                    <name>spark</name>
+                    <value>4.0</value>
+                </property>
+            </activation>
+            <properties>
+                <spark.version>4.0.0-preview2</spark.version>
+                <spark.compat.version>3.5</spark.compat.version>
+                <log4j.version>2.22.1</log4j.version>
+                <graphframe.version>0.8.3-spark3.5</graphframe.version>
+                <scala.version>2.13.12</scala.version>
+                <scala.compat.version>2.13</scala.compat.version>
+                <!-- Skip deploying parent module. It will be deployed with 
sedona-spark-3.3 -->
+                <skip.deploy.common.modules>true</skip.deploy.common.modules>
+            </properties>
+        </profile>
         <profile>
             <id>scala2.13</id>
             <activation>
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/DataFrameAPI.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/DataFrameAPI.scala
index 0b3a041fda..8c77f7f9ab 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/DataFrameAPI.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/DataFrameAPI.scala
@@ -22,51 +22,88 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 
 trait DataFrameAPI {
+
+  // The Column class changed in Spark 4, removing any direct usage of 
Expression
+  // Spark 3 Expression based methods:
+  val exprMethods =
+    try {
+      val exprMethod = Column.getClass().getDeclaredMethod("expr")
+      val constructor = 
Column.getClass().getDeclaredConstructor(classOf[Expression])
+      Some(exprMethod, constructor)
+    } catch {
+      case _: NoSuchMethodException => None
+    }
+
+  // Spark 4 method using Column.fn
+  val fnMethod =
+    try {
+      Some(Column.getClass().getDeclaredMethod("fn", classOf[String], 
classOf[Array[Column]]))
+    } catch {
+      case _: NoSuchMethodException => None
+    }
+
   protected def wrapExpression[E <: Expression: ClassTag](args: Any*): Column 
= {
-    val exprArgs = args.map(_ match {
-      case c: Column => c.expr
-      case s: String => Column(s).expr
-      case e: Expression => e
-      case x: Any => Literal(x)
-      case null => Literal(null)
-    })
-    val expressionConstructor =
-      
implicitly[ClassTag[E]].runtimeClass.getConstructor(classOf[Seq[Expression]])
-    val expressionInstance = 
expressionConstructor.newInstance(exprArgs).asInstanceOf[E]
-    Column(expressionInstance)
+    wrapVarArgExpression[E](args)
   }
 
   protected def wrapVarArgExpression[E <: Expression: ClassTag](arg: 
Seq[Any]): Column = {
-    val exprArgs = arg.map(_ match {
-      case c: Column => c.expr
-      case s: String => Column(s).expr
-      case e: Expression => e
-      case x: Any => Literal(x)
-      case null => Literal(null)
-    })
-    val expressionConstructor =
-      
implicitly[ClassTag[E]].runtimeClass.getConstructor(classOf[Seq[Expression]])
-    val expressionInstance = 
expressionConstructor.newInstance(exprArgs).asInstanceOf[E]
-    Column(expressionInstance)
+    val runtimeClass = implicitly[ClassTag[E]].runtimeClass
+    fnMethod
+      .map { fn =>
+        val colArgs = arg.map(_ match {
+          case c: Column => c
+          case s: String => Column(s)
+          case x: Any => lit(x)
+          case null => lit(null)
+        })
+        fn.invoke(runtimeClass.getSimpleName(), colArgs: 
_*).asInstanceOf[Column]
+      }
+      .getOrElse {
+        val (expr, constructor) = exprMethods.get
+        val exprArgs = arg.map(_ match {
+          case c: Column => expr.invoke(c).asInstanceOf[Expression]
+          case s: String => expr.invoke(Column(s)).asInstanceOf[Expression]
+          case e: Expression => e
+          case x: Any => Literal(x)
+          case null => Literal(null)
+        })
+        val expressionConstructor = 
runtimeClass.getConstructor(classOf[Seq[Expression]])
+        val expressionInstance = 
expressionConstructor.newInstance(exprArgs).asInstanceOf[E]
+        constructor.newInstance(expressionInstance).asInstanceOf[Column]
+      }
   }
 
   protected def wrapAggregator[A <: UserDefinedAggregateFunction: 
ClassTag](arg: Any*): Column = {
-    val exprArgs = arg.map(_ match {
-      case c: Column => c.expr
-      case s: String => Column(s).expr
-      case e: Expression => e
-      case x: Any => Literal(x)
-      case null => Literal(null)
-    })
-    val aggregatorClass = implicitly[ClassTag[A]].runtimeClass
-    val aggregatorConstructor = aggregatorClass.getConstructor()
-    val aggregatorInstance =
-      
aggregatorConstructor.newInstance().asInstanceOf[UserDefinedAggregateFunction]
-    val scalaAggregator = ScalaUDAF(exprArgs, aggregatorInstance)
-    Column(scalaAggregator)
+    val runtimeClass = implicitly[ClassTag[A]].runtimeClass
+    fnMethod
+      .map { fn =>
+        val colArgs = arg.map(_ match {
+          case c: Column => c
+          case s: String => Column(s)
+          case x: Any => lit(x)
+          case null => lit(null)
+        })
+        fn.invoke(runtimeClass.getSimpleName(), colArgs: 
_*).asInstanceOf[Column]
+      }
+      .getOrElse {
+        val (expr, constructor) = exprMethods.get
+        val exprArgs = arg.map(_ match {
+          case c: Column => expr.invoke(c).asInstanceOf[Expression]
+          case s: String => expr.invoke(Column(s)).asInstanceOf[Expression]
+          case e: Expression => e
+          case x: Any => Literal(x)
+          case null => Literal(null)
+        })
+        val aggregatorConstructor = runtimeClass.getConstructor()
+        val aggregatorInstance =
+          
aggregatorConstructor.newInstance().asInstanceOf[UserDefinedAggregateFunction]
+        val scalaAggregator = ScalaUDAF(exprArgs, aggregatorInstance)
+        constructor.newInstance(scalaAggregator).asInstanceOf[Column]
+      }
   }
 }

Reply via email to