This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 76093b5a2d [SEDONA-708] Separate Catalog implementation into AbstractCatalog and Catalog (#1798)
76093b5a2d is described below
commit 76093b5a2d3058484048a2c467338f51fce51498
Author: James Willis <[email protected]>
AuthorDate: Mon Feb 10 18:29:22 2025 -0800
[SEDONA-708] Separate Catalog implementation into AbstractCatalog and Catalog (#1798)
---
.../apache/sedona/sql/UDF/AbstractCatalog.scala | 77 ++++++++++++++++++++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 54 +--------------
2 files changed, 80 insertions(+), 51 deletions(-)
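
For context, the effect of this split is that the registration machinery (the FunctionDescription type and the reflective function[T] helper) now lives in AbstractCatalog, so another module can publish its own function set by extending it. A minimal sketch, assuming a hypothetical CustomCatalog object and reusing ST_Collect and ST_Envelope_Aggr from the imports visible in the diff below:

import org.apache.sedona.sql.UDF.AbstractCatalog
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.locationtech.jts.geom.Geometry

// Hypothetical downstream catalog: only the two abstract members of
// AbstractCatalog (defined in the new file below) need to be supplied.
object CustomCatalog extends AbstractCatalog {
  override val expressions: Seq[FunctionDescription] = Seq(
    function[ST_Collect]())

  override val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
    Seq(new ST_Envelope_Aggr)
}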
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
new file mode 100644
index 0000000000..3ad579c38c
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.UDF
+
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
+import org.apache.spark.sql.expressions.Aggregator
+import org.locationtech.jts.geom.Geometry
+
+import scala.reflect.ClassTag
+
+abstract class AbstractCatalog {
+
+  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
+
+ val expressions: Seq[FunctionDescription]
+
+ val aggregateExpressions: Seq[Aggregator[Geometry, _, _]]
+
+  protected def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
+ val classTag = implicitly[ClassTag[T]]
+    val constructor = classTag.runtimeClass.getConstructor(classOf[Seq[Expression]])
+ val functionName = classTag.runtimeClass.getSimpleName
+ val functionIdentifier = FunctionIdentifier(functionName)
+ val expressionInfo = new ExpressionInfo(
+ classTag.runtimeClass.getCanonicalName,
+ functionIdentifier.database.orNull,
+ functionName)
+
+ def functionBuilder(expressions: Seq[Expression]): T = {
+ val expr = constructor.newInstance(expressions).asInstanceOf[T]
+ expr match {
+ case e: ExpectsInputTypes =>
+ val numParameters = e.inputTypes.size
+ val numArguments = expressions.size
+        if (numParameters == numArguments || numParameters == expr.children.size) expr
+ else {
+ val numUnspecifiedArgs = numParameters - numArguments
+ if (numUnspecifiedArgs > 0) {
+ if (numUnspecifiedArgs <= defaultArgs.size) {
+ val args =
+                expressions ++ defaultArgs.takeRight(numUnspecifiedArgs).map(Literal(_))
+ constructor.newInstance(args).asInstanceOf[T]
+ } else {
+                throw new IllegalArgumentException(s"function $functionName takes at least " +
+                  s"${numParameters - defaultArgs.size} argument(s), $numArguments argument(s) specified")
+ }
+ } else {
+ throw new IllegalArgumentException(
+ s"function $functionName takes at most " +
+              s"$numParameters argument(s), $numArguments argument(s) specified")
+ }
+ }
+ case _ => expr
+ }
+ }
+
+ (functionIdentifier, expressionInfo, functionBuilder)
+ }
+}
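
Worth noting about the builder above: when the constructed expression mixes in ExpectsInputTypes and the caller supplies fewer arguments than declared input types, the gap is padded from the right end of defaultArgs (the numParameters == expr.children.size check is new relative to the old Catalog code removed below, and additionally accepts expressions whose children already line up). A small sketch of the padding arithmetic in isolation, with hypothetical values:

// Hypothetical values mirroring functionBuilder's padding step.
val numParameters = 3                 // e.inputTypes.size
val numArguments = 2                  // expressions.size, caller-supplied
val defaultArgs = Seq[Any](0, "srid") // as if registered via function[T](0, "srid")

val numUnspecifiedArgs = numParameters - numArguments // 1
// takeRight keeps the rightmost defaults, so earlier defaults are only
// consumed when more trailing arguments are omitted.
val padding = defaultArgs.takeRight(numUnspecifiedArgs) // Seq("srid")
// functionBuilder wraps each padded value in Literal(_) and re-invokes
// the Seq[Expression] constructor with expressions ++ padding.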
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 0bffa54baf..16c393cdbc 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -18,9 +18,6 @@
*/
package org.apache.sedona.sql.UDF
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _}
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
@@ -29,13 +26,10 @@ import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters
import scala.collection.mutable.ListBuffer
-import scala.reflect.ClassTag
-object Catalog {
+object Catalog extends AbstractCatalog {
-  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
-
- val expressions: Seq[FunctionDescription] = Seq(
+ override val expressions: Seq[FunctionDescription] = Seq(
// Expression for vectors
function[GeometryType](),
function[ST_LabelPoint](),
@@ -349,52 +343,10 @@ object Catalog {
function[ST_BinaryDistanceBandColumn](),
function[ST_WeightedDistanceBandColumn]())
- // Aggregate functions with Geometry as buffer
- val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
+ val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
// Aggregate functions with List as buffer
val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
Seq(new ST_Union_Aggr())
-
-  private def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
- val classTag = implicitly[ClassTag[T]]
-    val constructor = classTag.runtimeClass.getConstructor(classOf[Seq[Expression]])
- val functionName = classTag.runtimeClass.getSimpleName
- val functionIdentifier = FunctionIdentifier(functionName)
- val expressionInfo = new ExpressionInfo(
- classTag.runtimeClass.getCanonicalName,
- functionIdentifier.database.orNull,
- functionName)
-
- def functionBuilder(expressions: Seq[Expression]): T = {
- val expr = constructor.newInstance(expressions).asInstanceOf[T]
- expr match {
- case e: ExpectsInputTypes =>
- val numParameters = e.inputTypes.size
- val numArguments = expressions.size
- if (numParameters == numArguments) expr
- else {
- val numUnspecifiedArgs = numParameters - numArguments
- if (numUnspecifiedArgs > 0) {
- if (numUnspecifiedArgs <= defaultArgs.size) {
- val args =
-            expressions ++ defaultArgs.takeRight(numUnspecifiedArgs).map(Literal(_))
- constructor.newInstance(args).asInstanceOf[T]
- } else {
-            throw new IllegalArgumentException(s"function $functionName takes at least " +
-              s"${numParameters - defaultArgs.size} argument(s), $numArguments argument(s) specified")
- }
- } else {
- throw new IllegalArgumentException(
- s"function $functionName takes at most " +
-              s"$numParameters argument(s), $numArguments argument(s) specified")
- }
- }
- case _ => expr
- }
- }
-
- (functionIdentifier, expressionInfo, functionBuilder)
- }
}
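
Registration of the triples is unchanged by this commit. For illustration, a minimal sketch of how a catalog's entries feed Spark's function registry, assuming Spark's internal FunctionRegistry API (not part of this diff; Sedona's actual registrar may differ):

import org.apache.sedona.sql.UDF.AbstractCatalog
import org.apache.spark.sql.SparkSession

def registerAll(spark: SparkSession, catalog: AbstractCatalog): Unit = {
  // Each FunctionDescription unpacks into exactly the three values
  // that FunctionRegistry.registerFunction expects.
  val registry = spark.sessionState.functionRegistry
  catalog.expressions.foreach { case (ident, info, builder) =>
    registry.registerFunction(ident, info, builder)
  }
}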