This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 3a3b8d35ec [SEDONA 710] Rename Geostats SQL classes to generic name; merge UdfRegistrator into AbstractCatalog (#1809)
3a3b8d35ec is described below
commit 3a3b8d35ec3a132804ef60d53567ee89d67940c6
Author: James Willis <[email protected]>
AuthorDate: Wed Feb 12 16:23:39 2025 -0800
[SEDONA 710] Rename Geostats SQL classes to generic name; merge UdfRegistrator into AbstractCatalog (#1809)
Co-authored-by: jameswillis <[email protected]>
---
.../org/apache/sedona/spark/SedonaContext.scala | 19 ++--
.../apache/sedona/sql/UDF/AbstractCatalog.scala | 25 +++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 10 +-
.../org/apache/sedona/sql/UDF/UdfRegistrator.scala | 54 ---------
.../sedona/sql/utils/SedonaSQLRegistrator.scala | 4 +-
.../sedona_sql/expressions/GeoStatsFunctions.scala | 103 ++++-------------
.../sedona_sql/expressions/PhysicalFunction.scala | 109 ++++++++++++++++++
.../optimization/ExtractGeoStatsFunctions.scala | 120 --------------------
.../optimization/ExtractPhysicalFunctions.scala | 122 +++++++++++++++++++++
...tsFunction.scala => EvalPhysicalFunction.scala} | 2 +-
.../function/EvalPhysicalFunctionExec.scala} | 8 +-
.../function/EvalPhysicalFunctionStrategy.scala} | 12 +-
12 files changed, 305 insertions(+), 283 deletions(-)
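A minimal usage sketch (illustrative only, not part of the commit): after this change, the registration entry points formerly on UdfRegistrator live on the Catalog object, and SedonaContext.create wires them into the session. The session setup and app name below are assumptions made for the example.

    import org.apache.sedona.spark.SedonaContext
    import org.apache.sedona.sql.UDF.Catalog
    import org.apache.spark.sql.SparkSession

    object CatalogRegistrationSketch {
      def main(args: Array[String]): Unit = {
        // Plain Spark session; SedonaContext.create registers Sedona UDTs, UDFs and optimizer rules.
        val spark = SparkSession.builder().master("local[*]").appName("catalog-sketch").getOrCreate()
        val sedona = SedonaContext.create(spark)

        // Entry points merged from UdfRegistrator into AbstractCatalog:
        Catalog.registerAll(sedona) // registers all expressions and aggregators on the session
        Catalog.dropAll(sedona)     // removes them from the session's function registry
      }
    }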
diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index fe2926fc51..7cfb8670be 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -21,12 +21,12 @@ package org.apache.sedona.spark
import org.apache.sedona.common.utils.TelemetryCollector
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.sql.sedona_sql.optimization.{ExtractGeoStatsFunctions, SpatialFilterPushDownForGeoParquet, SpatialTemporalFilterPushDownForStacScan}
-import org.apache.spark.sql.sedona_sql.strategy.geostats.EvalGeoStatsFunctionStrategy
+import org.apache.spark.sql.sedona_sql.optimization._
import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
+import org.apache.spark.sql.sedona_sql.strategy.physical.function.EvalPhysicalFunctionStrategy
import org.apache.spark.sql.{SQLContext, SparkSession}
import scala.annotation.StaticAnnotation
@@ -73,20 +73,21 @@ object SedonaContext {
}
}
- // Support geostats functions
- if (!sparkSession.experimental.extraOptimizations.contains(ExtractGeoStatsFunctions)) {
- sparkSession.experimental.extraOptimizations ++= Seq(ExtractGeoStatsFunctions)
+ // Support physical functions
+ if (!sparkSession.experimental.extraOptimizations.contains(ExtractPhysicalFunctions)) {
+ sparkSession.experimental.extraOptimizations ++= Seq(ExtractPhysicalFunctions)
}
+
if (!sparkSession.experimental.extraStrategies.exists(
- _.isInstanceOf[EvalGeoStatsFunctionStrategy])) {
+ _.isInstanceOf[EvalPhysicalFunctionStrategy])) {
sparkSession.experimental.extraStrategies ++= Seq(
- new EvalGeoStatsFunctionStrategy(sparkSession))
+ new EvalPhysicalFunctionStrategy(sparkSession))
}
addGeoParquetToSupportNestedFilterSources(sparkSession)
RasterRegistrator.registerAll(sparkSession)
UdtRegistrator.registerAll()
- UdfRegistrator.registerAll(sparkSession)
+ Catalog.registerAll(sparkSession)
sparkSession
}
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
index 3ad579c38c..f8bb0ac5fe 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
@@ -18,6 +18,7 @@
*/
package org.apache.sedona.sql.UDF
+import org.apache.spark.sql.{SQLContext, SparkSession, functions}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
@@ -74,4 +75,28 @@ abstract class AbstractCatalog {
(functionIdentifier, expressionInfo, functionBuilder)
}
+
+ def registerAll(sqlContext: SQLContext): Unit = {
+ registerAll(sqlContext.sparkSession)
+ }
+
+ def registerAll(sparkSession: SparkSession): Unit = {
+ Catalog.expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) =>
+ sparkSession.sessionState.functionRegistry.registerFunction(
+ functionIdentifier,
+ expressionInfo,
+ functionBuilder)
+ }
+ Catalog.aggregateExpressions.foreach(f =>
+ sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
+ }
+
+ def dropAll(sparkSession: SparkSession): Unit = {
+ Catalog.expressions.foreach { case (functionIdentifier, _, _) =>
+ sparkSession.sessionState.functionRegistry.dropFunction(functionIdentifier)
+ }
+ Catalog.aggregateExpressions.foreach(f =>
+ sparkSession.sessionState.functionRegistry.dropFunction(
+ FunctionIdentifier(f.getClass.getSimpleName)))
+ }
}
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 16c393cdbc..af51a825f8 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -19,14 +19,12 @@
package org.apache.sedona.sql.UDF
import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _}
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.apache.spark.sql.sedona_sql.expressions.raster._
+import org.apache.spark.sql.sedona_sql.expressions._
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters
-import scala.collection.mutable.ListBuffer
-
object Catalog extends AbstractCatalog {
override val expressions: Seq[FunctionDescription] = Seq(
@@ -344,9 +342,5 @@ object Catalog extends AbstractCatalog {
function[ST_WeightedDistanceBandColumn]())
val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
- Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
-
- // Aggregate functions with List as buffer
- val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
- Seq(new ST_Union_Aggr())
+ Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr())
}
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
deleted file mode 100644
index 30c3cb2e3b..0000000000
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sedona.sql.UDF
-
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.{SQLContext, SparkSession, functions}
-
-object UdfRegistrator {
-
- def registerAll(sqlContext: SQLContext): Unit = {
- registerAll(sqlContext.sparkSession)
- }
-
- def registerAll(sparkSession: SparkSession): Unit = {
- Catalog.expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) =>
- sparkSession.sessionState.functionRegistry.registerFunction(
- functionIdentifier,
- expressionInfo,
- functionBuilder)
- }
- Catalog.aggregateExpressions.foreach(f =>
- sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
-
- Catalog.aggregateExpressions2.foreach(f =>
- sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
- }
-
- def dropAll(sparkSession: SparkSession): Unit = {
- Catalog.expressions.foreach { case (functionIdentifier, _, _) =>
- sparkSession.sessionState.functionRegistry.dropFunction(functionIdentifier)
- }
- Catalog.aggregateExpressions.foreach(f =>
- sparkSession.sessionState.functionRegistry.dropFunction(
- FunctionIdentifier(f.getClass.getSimpleName)
- )) // SPARK3 anchor
-//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK2 anchor
- }
-}
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 52f7ceb1cd..a679db084d 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -20,7 +20,7 @@ package org.apache.sedona.sql.utils
import org.apache.sedona.spark.SedonaContext
import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
import org.apache.spark.sql.{SQLContext, SparkSession}
@deprecated("Use SedonaContext instead", "1.4.1")
@@ -44,7 +44,7 @@ object SedonaSQLRegistrator {
SedonaContext.create(sparkSession, language)
def dropAll(sparkSession: SparkSession): Unit = {
- UdfRegistrator.dropAll(sparkSession)
+ Catalog.dropAll(sparkSession)
RasterRegistrator.dropAll(sparkSession)
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
index 8c6b645daf..75e86510ab 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
@@ -23,80 +23,13 @@ import org.apache.sedona.stats.Weighting.{addBinaryDistanceBandColumn, addWeight
import org.apache.sedona.stats.clustering.DBSCAN.dbscan
import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ImplicitCastInputTypes, Literal, ScalarSubquery, Unevaluable}
-import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import scala.reflect.ClassTag
-
-// We mark ST_GeoStatsFunction as non-deterministic to avoid the filter push-down optimization pass
-// duplicates the ST_GeoStatsFunction when pushing down aliased ST_GeoStatsFunction through a
-// Project operator. This will make ST_GeoStatsFunction being evaluated twice.
-trait ST_GeoStatsFunction
- extends Expression
- with ImplicitCastInputTypes
- with Unevaluable
- with Serializable {
-
- final override lazy val deterministic: Boolean = false
-
- override def nullable: Boolean = true
-
- private final lazy val sparkSession = SparkSession.getActiveSession.get
-
- protected final lazy val geometryColumnName = getInputName(0, "geometry")
-
- protected def getInputName(i: Int, fieldName: String): String = children(i) match {
- case ref: AttributeReference => ref.name
- case _ =>
- throw new IllegalArgumentException(
- f"$fieldName argument must be a named reference to an existing column")
- }
-
- protected def getInputNames(i: Int, fieldName: String): Seq[String] = children(
- i).dataType match {
- case StructType(fields) => fields.map(_.name)
- case _ => throw new IllegalArgumentException(f"$fieldName argument must be a struct")
- }
-
- protected def getResultName(resultAttrs: Seq[Attribute]): String = resultAttrs match {
- case Seq(attr) => attr.name
- case _ => throw new IllegalArgumentException("resultAttrs must have exactly one attribute")
- }
-
- protected def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame
-
- protected def getScalarValue[T](i: Int, name: String)(implicit ct: ClassTag[T]): T = {
- children(i) match {
- case Literal(l: T, _) => l
- case _: Literal =>
- throw new IllegalArgumentException(f"$name must be an instance of ${ct.runtimeClass}")
- case s: ScalarSubquery =>
- s.eval() match {
- case t: T => t
- case _ =>
- throw new IllegalArgumentException(
- f"$name must be an instance of ${ct.runtimeClass}")
- }
- case _ => throw new IllegalArgumentException(f"$name must be a scalar value")
- }
- }
-
- def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow] = {
- val df = doExecute(
- Dataset.ofRows(sparkSession, LogicalRDD(plan.output, plan.execute())(sparkSession)),
- df.queryExecution.toRdd
- }
-
-}
-
-case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_DBSCAN(children: Seq[Expression]) extends DataframePhysicalFunction {
override def dataType: DataType = StructType(
Seq(StructField("isCore", BooleanType), StructField("cluster", LongType)))
@@ -107,7 +40,9 @@ case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
- override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+ override def transformDataframe(
+ dataframe: DataFrame,
+ resultAttrs: Seq[Attribute]): DataFrame = {
require(
!dataframe.columns.contains("__isCore"),
"__isCore is a reserved name by the dbscan algorithm. Please rename the
columns before calling the ST_DBSCAN function.")
@@ -129,7 +64,7 @@ case class ST_DBSCAN(children: Seq[Expression]) extends
ST_GeoStatsFunction {
}
}
-case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_LocalOutlierFactor(children: Seq[Expression]) extends DataframePhysicalFunction {
override def dataType: DataType = DoubleType
@@ -139,7 +74,9 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsF
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
- override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+ override def transformDataframe(
+ dataframe: DataFrame,
+ resultAttrs: Seq[Attribute]): DataFrame = {
localOutlierFactor(
dataframe,
getScalarValue[Int](1, "k"),
@@ -150,7 +87,7 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsF
}
}
-case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_GLocal(children: Seq[Expression]) extends DataframePhysicalFunction {
override def dataType: DataType = StructType(
Seq(
@@ -172,7 +109,9 @@ case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
- override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+ override def transformDataframe(
+ dataframe: DataFrame,
+ resultAttrs: Seq[Attribute]): DataFrame = {
gLocal(
dataframe,
getInputName(0, "x"),
@@ -187,7 +126,8 @@ case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
}
}
-case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_BinaryDistanceBandColumn(children: Seq[Expression])
+ extends DataframePhysicalFunction {
override def dataType: DataType = ArrayType(
StructType(
Seq(StructField("neighbor", children(5).dataType), StructField("value", DoubleType))))
@@ -198,7 +138,9 @@ case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_Geo
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
- override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+ override def transformDataframe(
+ dataframe: DataFrame,
+ resultAttrs: Seq[Attribute]): DataFrame = {
val attributeNames = getInputNames(5, "attributes")
require(attributeNames.nonEmpty, "attributes must have at least one column")
require(
@@ -217,7 +159,8 @@ case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_Geo
}
}
-case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_WeightedDistanceBandColumn(children: Seq[Expression])
+ extends DataframePhysicalFunction {
override def dataType: DataType = ArrayType(
StructType(
@@ -237,7 +180,9 @@ case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends ST_G
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
- override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+ override def transformDataframe(
+ dataframe: DataFrame,
+ resultAttrs: Seq[Attribute]): DataFrame = {
val attributeNames = getInputNames(7, "attributes")
require(attributeNames.nonEmpty, "attributes must have at least one column")
require(
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala
new file mode 100644
index 0000000000..253dfe2cfa
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ImplicitCastInputTypes, Literal, ScalarSubquery, Unevaluable}
+import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+
+import scala.reflect.ClassTag
+
+/**
+ * PhysicalFunctions are Functions that will be replaced with a Physical Node for their
+ * evaluation.
+ *
+ * execute is the method that will be called in order to evaluate the function. PhysicalFunction
+ * is marked non-deterministic to avoid the filter push-down optimization pass which duplicates
+ * the PhysicalFunction when pushing down aliased PhysicalFunction calls through a Project
+ * operator. Otherwise the PhysicalFunction would be evaluated twice.
+ */
+trait PhysicalFunction
+ extends Expression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with Serializable {
+ final override lazy val deterministic: Boolean = false
+
+ override def nullable: Boolean = true
+
+ protected final lazy val sparkSession = SparkSession.getActiveSession.get
+
+ protected final lazy val geometryColumnName = getInputName(0, "geometry")
+
+ protected def getInputName(i: Int, fieldName: String): String = children(i) match {
+ case ref: AttributeReference => ref.name
+ case _ =>
+ throw new IllegalArgumentException(
+ f"$fieldName argument must be a named reference to an existing column")
+ }
+
+ protected def getScalarValue[T](i: Int, name: String)(implicit ct: ClassTag[T]): T = {
+ children(i) match {
+ case Literal(l: T, _) => l
+ case _: Literal =>
+ throw new IllegalArgumentException(f"$name must be an instance of ${ct.runtimeClass}")
+ case s: ScalarSubquery =>
+ s.eval() match {
+ case t: T => t
+ case _ =>
+ throw new IllegalArgumentException(
+ f"$name must be an instance of ${ct.runtimeClass}")
+ }
+ case _ => throw new IllegalArgumentException(f"$name must be a scalar value")
+ }
+ }
+
+ protected def getInputNames(i: Int, fieldName: String): Seq[String] = children(
+ i).dataType match {
+ case StructType(fields) => fields.map(_.name)
+ case _ => throw new IllegalArgumentException(f"$fieldName argument must be a struct")
+ }
+
+ protected def getResultName(resultAttrs: Seq[Attribute]): String = resultAttrs match {
+ case Seq(attr) => attr.name
+ case _ => throw new IllegalArgumentException("resultAttrs must have exactly one attribute")
+ }
+
+ def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow]
+}
+
+/**
+ * DataframePhysicalFunctions are Functions that will be replaced with a Physical Node for their
+ * evaluation.
+ *
+ * The physical node will transform the input dataframe into the output dataframe. execute handles
+ * conversion of the RDD[InternalRow] to a DataFrame and back. Each DataframePhysicalFunction
+ * should implement transformDataframe. The output dataframe should have the same schema as the
+ * input dataframe, except for the resultAttrs which should be added to the output dataframe.
+ */
+trait DataframePhysicalFunction extends PhysicalFunction {
+
+ protected def transformDataframe(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame
+
+ override def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow] = {
+ val df = transformDataframe(
+ Dataset.ofRows(sparkSession, LogicalRDD(plan.output, plan.execute())(sparkSession)),
+ resultAttrs)
+ df.queryExecution.toRdd
+ }
+
+}
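As an illustrative aside (not part of the diff), a new function built on the DataframePhysicalFunction trait above could look roughly like the sketch below. ST_ConstantLabel, its constant output, and the assumed single-geometry input type are invented for the example; real implementations, such as ST_DBSCAN in this commit, also validate their arguments.

    package org.apache.spark.sql.sedona_sql.expressions

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.functions.lit
    import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
    import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

    // Hypothetical example: appends a constant label column to every row.
    case class ST_ConstantLabel(children: Seq[Expression]) extends DataframePhysicalFunction {

      // Type of the result column added by the physical node.
      override def dataType: DataType = StringType

      // Assumed single geometry argument, mirroring the functions in this commit.
      override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT)

      protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
        copy(children = newChildren)

      // Adds the column named by resultAttrs; all input columns pass through untouched.
      override def transformDataframe(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame =
        dataframe.withColumn(getResultName(resultAttrs), lit("labeled"))
    }

When such an expression appears in a query, ExtractPhysicalFunctions rewrites the plan to evaluate it in an EvalPhysicalFunction node, and EvalPhysicalFunctionStrategy plans that node into EvalPhysicalFunctionExec.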
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
deleted file mode 100644
index 6b4cf9ccea..0000000000
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.sql.sedona_sql.optimization
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
-import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
-
-import scala.collection.mutable
-
-/**
- * Extracts GeoStats functions from operators, rewriting the query plan so that the geo-stats
- * functions can be evaluated alone in its own physical executors.
- */
-object ExtractGeoStatsFunctions extends Rule[LogicalPlan] {
- var geoStatsResultCount = 0
-
- private def collectGeoStatsFunctionsFromExpressions(
- expressions: Seq[Expression]): Seq[ST_GeoStatsFunction] = {
- def collectGeoStatsFunctions(expr: Expression): Seq[ST_GeoStatsFunction] = expr match {
- case expr: ST_GeoStatsFunction => Seq(expr)
- case e => e.children.flatMap(collectGeoStatsFunctions)
- }
- expressions.flatMap(collectGeoStatsFunctions)
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan match {
- // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
- // eventually. Here we skip subquery, as geo-stats functions only needs to be extracted once.
- case s: Subquery if s.correlated => plan
- case _ =>
- plan.transformUp {
- case p: EvalGeoStatsFunction => p
- case plan: LogicalPlan => extract(plan)
- }
- }
-
- private def canonicalizeDeterministic(u: ST_GeoStatsFunction) = {
- if (u.deterministic) {
- u.canonicalized.asInstanceOf[ST_GeoStatsFunction]
- } else {
- u
- }
- }
-
- /**
- * Extract all the geo-stats functions from the current operator and evaluate them before the
- * operator.
- */
- private def extract(plan: LogicalPlan): LogicalPlan = {
- val geoStatsFuncs = plan match {
- case e: EvalGeoStatsFunction =>
- collectGeoStatsFunctionsFromExpressions(e.function.children)
- case _ =>
- ExpressionSet(collectGeoStatsFunctionsFromExpressions(plan.expressions))
- // ignore the ST_GeoStatsFunction that come from second/third aggregate, which is not used
- .filter(func => func.references.subsetOf(plan.inputSet))
- .filter(func =>
- plan.children.exists(child => func.references.subsetOf(child.outputSet)))
- .toSeq
- .asInstanceOf[Seq[ST_GeoStatsFunction]]
- }
-
- if (geoStatsFuncs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- // Transform the first geo-stats function we have found. We'll call extract recursively later
- // to transform the rest.
- val geoStatsFunc = geoStatsFuncs.head
-
- val attributeMap = mutable.HashMap[ST_GeoStatsFunction, Expression]()
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- if (geoStatsFunc.references.subsetOf(child.outputSet)) {
- geoStatsResultCount += 1
- val resultAttr =
- AttributeReference(f"geoStatsResult$geoStatsResultCount", geoStatsFunc.dataType)()
- val evaluation = EvalGeoStatsFunction(geoStatsFunc, Seq(resultAttr), child)
- attributeMap += (canonicalizeDeterministic(geoStatsFunc) -> resultAttr)
- extract(evaluation) // handle nested geo-stats functions
- } else {
- child
- }
- }
-
- // Replace the geo stats function call with the newly created geoStatsResult attribute
- val rewritten = plan.withNewChildren(newChildren).transformExpressions {
- case p: ST_GeoStatsFunction => attributeMap.getOrElse(canonicalizeDeterministic(p), p)
- }
-
- // extract remaining geo-stats functions recursively
- val newPlan = extract(rewritten)
- if (newPlan.output != plan.output) {
- // Trim away the new UDF value if it was only used for filtering or something.
- Project(plan.output, newPlan)
- } else {
- newPlan
- }
- }
- }
-}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala
new file mode 100644
index 0000000000..9aac6db2b8
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.optimization
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalPhysicalFunction
+
+import scala.collection.mutable
+
+/**
+ * Extracts Physical functions from operators, rewriting the query plan so that the functions can
+ * be evaluated alone in its own physical executors.
+ */
+object ExtractPhysicalFunctions extends Rule[LogicalPlan] {
+ private var physicalFunctionResultCount = 0
+
+ private def collectPhysicalFunctionsFromExpressions(
+ expressions: Seq[Expression]): Seq[PhysicalFunction] = {
+ def collectPhysicalFunctions(expr: Expression): Seq[PhysicalFunction] = expr match {
+ case expr: PhysicalFunction => Seq(expr)
+ case e => e.children.flatMap(collectPhysicalFunctions)
+ }
+ expressions.flatMap(collectPhysicalFunctions)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+ // eventually. Here we skip subquery, as physical functions only needs to be extracted once.
+ case s: Subquery if s.correlated => plan
+ case _ =>
+ plan.transformUp {
+ case p: EvalPhysicalFunction => p
+ case plan: LogicalPlan => extract(plan)
+ }
+ }
+
+ private def canonicalizeDeterministic(u: PhysicalFunction) = {
+ if (u.deterministic) {
+ u.canonicalized.asInstanceOf[PhysicalFunction]
+ } else {
+ u
+ }
+ }
+
+ /**
+ * Extract all the physical functions from the current operator and evaluate them before the
+ * operator.
+ */
+ private def extract(plan: LogicalPlan): LogicalPlan = {
+ val physicalFunctions = plan match {
+ case e: EvalPhysicalFunction =>
+ collectPhysicalFunctionsFromExpressions(e.function.children)
+ case _ =>
+ ExpressionSet(collectPhysicalFunctionsFromExpressions(plan.expressions))
+ // ignore the PhysicalFunction that come from second/third aggregate, which is not used
+ .filter(func => func.references.subsetOf(plan.inputSet))
+ .filter(func =>
+ plan.children.exists(child => func.references.subsetOf(child.outputSet)))
+ .toSeq
+ .asInstanceOf[Seq[PhysicalFunction]]
+ }
+
+ if (physicalFunctions.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ // Transform the first physical function we have found. We'll call extract recursively later
+ // to transform the rest.
+ val physicalFunction = physicalFunctions.head
+
+ val attributeMap = mutable.HashMap[PhysicalFunction, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ if (physicalFunction.references.subsetOf(child.outputSet)) {
+ physicalFunctionResultCount += 1
+ val resultAttr =
+ AttributeReference(
+ f"physicalFunctionResult$physicalFunctionResultCount",
+ physicalFunction.dataType)()
+ val evaluation = EvalPhysicalFunction(physicalFunction, Seq(resultAttr), child)
+ attributeMap += (canonicalizeDeterministic(physicalFunction) -> resultAttr)
+ extract(evaluation) // handle nested functions
+ } else {
+ child
+ }
+ }
+
+ // Replace the physical function call with the newly created attribute
+ val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+ case p: PhysicalFunction => attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+ }
+
+ // extract remaining physical functions recursively
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
+ // Trim away the new UDF value if it was only used for filtering or something.
+ Project(plan.output, newPlan)
+ } else {
+ newPlan
+ }
+ }
+ }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
similarity index 97%
rename from spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
rename to spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
index 8daeb0c304..9371d0c12d 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.UnaryNode
-case class EvalGeoStatsFunction(
+case class EvalPhysicalFunction(
function: Expression,
resultAttrs: Seq[Attribute],
child: LogicalPlan)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
similarity index 87%
rename from spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
rename to spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
index fbecb69ec4..a99630dc0c 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
@@ -16,16 +16,16 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.spark.sql.sedona_sql.strategy.geostats
+package org.apache.spark.sql.sedona_sql.strategy.physical.function
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
-case class EvalGeoStatsFunctionExec(
- function: ST_GeoStatsFunction,
+case class EvalPhysicalFunctionExec(
+ function: PhysicalFunction,
child: SparkPlan,
resultAttrs: Seq[Attribute])
extends UnaryExecNode {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
similarity index 73%
rename from spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
rename to spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
index 4c10b747a6..a159badd38 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
@@ -16,21 +16,21 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.spark.sql.sedona_sql.strategy.geostats
+package org.apache.spark.sql.sedona_sql.strategy.physical.function
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalPhysicalFunction
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
-class EvalGeoStatsFunctionStrategy(spark: SparkSession) extends Strategy {
+class EvalPhysicalFunctionStrategy(spark: SparkSession) extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
plan match {
- case EvalGeoStatsFunction(function: ST_GeoStatsFunction, resultAttrs, child) =>
- EvalGeoStatsFunctionExec(function, planLater(child), resultAttrs) :: Nil
+ case EvalPhysicalFunction(function: PhysicalFunction, resultAttrs, child) =>
+ EvalPhysicalFunctionExec(function, planLater(child), resultAttrs) :: Nil
case _ => Nil
}
}