This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new aede0eb4dc [SEDONA-711] Add Geography user-defined type (#1828)
aede0eb4dc is described below
commit aede0eb4dcc1e03b27287c0354c79eb56a6ba3c9
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Feb 25 06:41:41 2025 +0800
[SEDONA-711] Add Geography user-defined type (#1828)
* start basic object
* compiling udt class
* maybe actually register
* add one input and one output function
* maybe one input and one output function
* maybe builds
* some possible Python requirements
* format
* maybe a few more references to geography
* remove word
* Making ST_GeogFromWKT, ST_AsEWKT, ST_AsEWKB work properly
* Add geography serde tests and dataframe api tests for python binding
---------
Co-authored-by: Dewey Dunnington <[email protected]>
---
.../org/apache/sedona/common/Constructors.java | 5 +
.../java/org/apache/sedona/common/Functions.java | 9 ++
.../sedona/common/geometryObjects/Geography.java | 22 ++--
.../sedona/common/geometrySerde/GeometrySerde.java | 13 +-
python/sedona/core/geom/geography.py | 25 ++++
python/sedona/register/java_libs.py | 1 +
python/sedona/spark/__init__.py | 2 +-
python/sedona/sql/st_constructors.py | 16 +++
python/sedona/sql/types.py | 26 ++++
python/sedona/utils/prep.py | 9 ++
python/tests/sql/test_dataframe_api.py | 7 ++
python/tests/sql/test_geography.py | 46 +++++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../spark/sql/sedona_sql/UDT/GeographyUDT.scala | 61 ++++++++++
.../sql/sedona_sql/UDT/UdtRegistratorWrapper.scala | 2 +
.../sql/sedona_sql/expressions/Constructors.scala | 14 +++
.../sedona_sql/expressions/FunctionResolver.scala | 135 +++++++++++++++++++++
.../sql/sedona_sql/expressions/Functions.scala | 9 +-
.../expressions/InferredExpression.scala | 42 +++++--
.../sql/sedona_sql/expressions/implicits.scala | 18 +++
.../sedona_sql/expressions/st_constructors.scala | 6 +
.../apache/sedona/sql/FunctionResolverSuite.scala | 99 +++++++++++++++
.../apache/sedona/sql/constructorTestScala.scala | 9 ++
.../apache/sedona/sql/dataFrameAPITestScala.scala | 18 ++-
.../org/apache/sedona/sql/functionTestScala.scala | 18 ++-
25 files changed, 589 insertions(+), 24 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java
b/common/src/main/java/org/apache/sedona/common/Constructors.java
index 3cd4729243..b53e8162dd 100644
--- a/common/src/main/java/org/apache/sedona/common/Constructors.java
+++ b/common/src/main/java/org/apache/sedona/common/Constructors.java
@@ -22,6 +22,7 @@ import java.io.IOException;
import javax.xml.parsers.ParserConfigurationException;
import org.apache.sedona.common.enums.FileDataSplitter;
import org.apache.sedona.common.enums.GeometryType;
+import org.apache.sedona.common.geometryObjects.Geography;
import org.apache.sedona.common.utils.FormatUtils;
import org.apache.sedona.common.utils.GeoHashDecoder;
import org.locationtech.jts.geom.*;
@@ -44,6 +45,10 @@ public class Constructors {
return new WKTReader(geometryFactory).read(wkt);
}
+ public static Geography geogFromWKT(String wkt, int srid) throws
ParseException {
+ return new Geography(geomFromWKT(wkt, srid));
+ }
+
public static Geometry geomFromEWKT(String ewkt) throws ParseException {
if (ewkt == null) {
return null;
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java
b/common/src/main/java/org/apache/sedona/common/Functions.java
index b5a181aa29..c4bcb67086 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -29,6 +29,7 @@ import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sedona.common.geometryObjects.Circle;
+import org.apache.sedona.common.geometryObjects.Geography;
import org.apache.sedona.common.sphere.Spheroid;
import org.apache.sedona.common.subDivide.GeometrySubDivider;
import org.apache.sedona.common.utils.*;
@@ -776,6 +777,10 @@ public class Functions {
return GeomUtils.getEWKT(geometry);
}
+ public static String asEWKT(Geography geography) {
+ return asEWKT(geography.getGeometry());
+ }
+
public static String asWKT(Geometry geometry) {
return GeomUtils.getWKT(geometry);
}
@@ -784,6 +789,10 @@ public class Functions {
return GeomUtils.getEWKB(geometry);
}
+ public static byte[] asEWKB(Geography geography) {
+ return asEWKB(geography.getGeometry());
+ }
+
public static String asHexEWKB(Geometry geom, String endian) {
if (endian.equalsIgnoreCase("NDR")) {
return GeomUtils.getHexEWKB(geom, ByteOrderValues.LITTLE_ENDIAN);
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
b/common/src/main/java/org/apache/sedona/common/geometryObjects/Geography.java
similarity index 66%
copy from
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
copy to
common/src/main/java/org/apache/sedona/common/geometryObjects/Geography.java
index a96d15c008..f1eba5deca 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++
b/common/src/main/java/org/apache/sedona/common/geometryObjects/Geography.java
@@ -16,16 +16,22 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.spark.sql.sedona_sql.UDT
+package org.apache.sedona.common.geometryObjects;
-import org.apache.spark.sql.types.UDTRegistration
-import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.index.SpatialIndex
+import org.locationtech.jts.geom.Geometry;
-object UdtRegistratorWrapper {
+public class Geography {
+ private final Geometry geometry;
- def registerAll(): Unit = {
- UDTRegistration.register(classOf[Geometry].getName,
classOf[GeometryUDT].getName)
- UDTRegistration.register(classOf[SpatialIndex].getName,
classOf[IndexUDT].getName)
+ public Geography(Geometry geometry) {
+ this.geometry = geometry;
+ }
+
+ public Geometry getGeometry() {
+ return this.geometry;
+ }
+
+ public String toString() {
+ return this.geometry.toText();
}
}
diff --git
a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerde.java
b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerde.java
index 7b7399cdb8..5475d37c7b 100644
---
a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerde.java
+++
b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerde.java
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import java.io.Serializable;
import org.apache.sedona.common.geometryObjects.Circle;
+import org.apache.sedona.common.geometryObjects.Geography;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryCollection;
@@ -36,7 +37,7 @@ import org.locationtech.jts.geom.Polygon;
* Provides methods to efficiently serialize and deserialize geometry types.
*
* <p>Supports Point, LineString, Polygon, MultiPoint, MultiLineString,
MultiPolygon,
- * GeometryCollection, Circle and Envelope types.
+ * GeometryCollection, Circle, Envelope, and Geography types.
*
* <p>First byte contains {@link Type#id}. Then go type-specific bytes,
followed by user-data
* attached to the geometry.
@@ -63,6 +64,9 @@ public class GeometrySerde extends Serializer implements
Serializable {
out.writeDouble(envelope.getMaxX());
out.writeDouble(envelope.getMinY());
out.writeDouble(envelope.getMaxY());
+ } else if (object instanceof Geography) {
+ writeType(out, Type.GEOGRAPHY);
+ writeGeometry(kryo, out, ((Geography) object).getGeometry());
} else {
throw new UnsupportedOperationException(
"Cannot serialize object of type " + object.getClass().getName());
@@ -118,6 +122,10 @@ public class GeometrySerde extends Serializer implements
Serializable {
return new Envelope();
}
}
+ case GEOGRAPHY:
+ {
+ return new Geography(readGeometry(kryo, input));
+ }
default:
throw new UnsupportedOperationException(
"Cannot deserialize object of type " + geometryType);
@@ -145,7 +153,8 @@ public class GeometrySerde extends Serializer implements
Serializable {
private enum Type {
SHAPE(0),
CIRCLE(1),
- ENVELOPE(2);
+ ENVELOPE(2),
+ GEOGRAPHY(3);
private final int id;
diff --git a/python/sedona/core/geom/geography.py
b/python/sedona/core/geom/geography.py
new file mode 100644
index 0000000000..5764e935bf
--- /dev/null
+++ b/python/sedona/core/geom/geography.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from shapely.geometry.base import BaseGeometry
+
+
+class Geography:
+ geometry: BaseGeometry
+
+ def __init__(self, geometry: BaseGeometry):
+ self.geometry = geometry
diff --git a/python/sedona/register/java_libs.py
b/python/sedona/register/java_libs.py
index 8d1681d15e..466827cbf4 100644
--- a/python/sedona/register/java_libs.py
+++ b/python/sedona/register/java_libs.py
@@ -26,6 +26,7 @@ class SedonaJvmLib(Enum):
KNNQuery = "org.apache.sedona.core.spatialOperator.KNNQuery"
RangeQuery = "org.apache.sedona.core.spatialOperator.RangeQuery"
Envelope = "org.locationtech.jts.geom.Envelope"
+ Geography = "org.apache.sedona.common.geometryObjects.Geography"
GeoSerializerData = (
"org.apache.sedona.python.wrapper.adapters.GeoSparkPythonConverter"
)
diff --git a/python/sedona/spark/__init__.py b/python/sedona/spark/__init__.py
index 50d1d1131e..8f4c4f24ed 100644
--- a/python/sedona/spark/__init__.py
+++ b/python/sedona/spark/__init__.py
@@ -42,7 +42,7 @@ from sedona.sql.st_aggregates import *
from sedona.sql.st_constructors import *
from sedona.sql.st_functions import *
from sedona.sql.st_predicates import *
-from sedona.sql.types import GeometryType, RasterType
+from sedona.sql.types import GeometryType, GeographyType, RasterType
from sedona.utils import KryoSerializer, SedonaKryoRegistrator
from sedona.utils.adapter import Adapter
from sedona.utils.geoarrow import dataframe_to_arrow
diff --git a/python/sedona/sql/st_constructors.py
b/python/sedona/sql/st_constructors.py
index f636edc0c2..c37f302af1 100644
--- a/python/sedona/sql/st_constructors.py
+++ b/python/sedona/sql/st_constructors.py
@@ -176,6 +176,22 @@ def ST_GeomFromWKT(
return _call_constructor_function("ST_GeomFromWKT", args)
+@validate_argument_types
+def ST_GeogFromWKT(
+ wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None
+) -> Column:
+ """Generate a geography column from a Well-Known Text (WKT) string column.
+
+ :param wkt: WKT string column to generate from.
+ :type wkt: ColumnOrName
+ :return: Geography column representing the WKT string.
+ :rtype: Column
+ """
+ args = (wkt) if srid is None else (wkt, srid)
+
+ return _call_constructor_function("ST_GeogFromWKT", args)
+
+
@validate_argument_types
def ST_GeomFromEWKT(ewkt: ColumnOrName) -> Column:
"""Generate a geometry column from a OGC Extended Well-Known Text (WKT)
string column.
diff --git a/python/sedona/sql/types.py b/python/sedona/sql/types.py
index c966d451ca..1b8a9ee03b 100644
--- a/python/sedona/sql/types.py
+++ b/python/sedona/sql/types.py
@@ -33,6 +33,7 @@ else:
SedonaRaster = None
from ..utils import geometry_serde
+from ..core.geom.geography import Geography
class GeometryType(UserDefinedType):
@@ -60,6 +61,31 @@ class GeometryType(UserDefinedType):
return "org.apache.spark.sql.sedona_sql.UDT.GeometryUDT"
+class GeographyType(UserDefinedType):
+
+ @classmethod
+ def sqlType(cls):
+ return BinaryType()
+
+ def serialize(self, obj):
+ return geometry_serde.serialize(obj.geometry)
+
+ def deserialize(self, datum):
+ geom, offset = geometry_serde.deserialize(datum)
+ return Geography(geom)
+
+ @classmethod
+ def module(cls):
+ return "sedona.sql.types"
+
+ def needConversion(self):
+ return True
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.sql.sedona_sql.UDT.GeographyUDT"
+
+
class RasterType(UserDefinedType):
@classmethod
diff --git a/python/sedona/utils/prep.py b/python/sedona/utils/prep.py
index c9528300cf..8cabfa05c2 100644
--- a/python/sedona/utils/prep.py
+++ b/python/sedona/utils/prep.py
@@ -28,6 +28,8 @@ from shapely.geometry import (
)
from shapely.geometry.base import BaseGeometry
+from ..core.geom.geography import Geography
+
def assign_all() -> bool:
geoms = [
@@ -41,6 +43,7 @@ def assign_all() -> bool:
]
assign_udt_shapely_objects(geoms=geoms)
assign_user_data_to_shapely_objects(geoms=geoms)
+ assign_udt_geography()
return True
@@ -55,3 +58,9 @@ def assign_udt_shapely_objects(geoms:
List[type(BaseGeometry)]) -> bool:
def assign_user_data_to_shapely_objects(geoms: List[type(BaseGeometry)]) ->
bool:
for geom in geoms:
geom.getUserData = lambda geom_instance: geom_instance.userData
+
+
+def assign_udt_geography():
+ from sedona.sql.types import GeographyType
+
+ Geography.__UDT__ = GeographyType()
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index de65f6f0f4..99d1f17754 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -26,6 +26,7 @@ from pyspark.sql import functions as f
from shapely.geometry.base import BaseGeometry
from tests.test_base import TestBase
+from sedona.core.geom.geography import Geography
from sedona.sql import st_aggregates as sta
from sedona.sql import st_constructors as stc
from sedona.sql import st_functions as stf
@@ -85,6 +86,8 @@ test_configurations = [
(stc.ST_GeomFromWKT, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3
4)"),
(stc.ST_GeomFromWKT, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1
2, 3 4)"),
(stc.ST_GeomFromEWKT, ("ewkt",), "linestring_ewkt", "", "LINESTRING (1 2,
3 4)"),
+ (stc.ST_GeogFromWKT, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3
4)"),
+ (stc.ST_GeogFromWKT, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1
2, 3 4)"),
(stc.ST_LineFromText, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3
4)"),
(
stc.ST_LineFromWKB,
@@ -1230,6 +1233,7 @@ wrong_type_configurations = [
(stc.ST_LinestringFromWKB, (None,)),
(stc.ST_GeomFromEWKB, (None,)),
(stc.ST_GeomFromWKT, (None,)),
+ (stc.ST_GeogFromWKT, (None,)),
(stc.ST_GeometryFromText, (None,)),
(stc.ST_LineFromText, (None,)),
(stc.ST_LineStringFromText, (None, "")),
@@ -1711,6 +1715,9 @@ class TestDataFrameAPI(TestBase):
if isinstance(actual_result, BaseGeometry):
self.assert_geometry_almost_equal(expected_result, actual_result)
return
+ elif isinstance(actual_result, Geography):
+ self.assert_geometry_almost_equal(expected_result,
actual_result.geometry)
+ return
elif isinstance(actual_result, bytearray):
actual_result = actual_result.hex()
elif isinstance(actual_result, Row):
diff --git a/python/tests/sql/test_geography.py
b/python/tests/sql/test_geography.py
new file mode 100644
index 0000000000..3105b6a8b4
--- /dev/null
+++ b/python/tests/sql/test_geography.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from pyspark.sql.functions import expr
+from pyspark.sql.types import StructType
+from shapely.wkt import loads as wkt_loads
+from sedona.core.geom.geography import Geography
+from sedona.sql.types import GeographyType
+from tests.test_base import TestBase
+
+
+class TestGeography(TestBase):
+
+ def test_deserialize_geography(self):
+ """Test serialization and deserialization of Geography objects"""
+ geog_df = self.spark.range(0, 10).withColumn(
+ "geog", expr("ST_GeogFromWKT(CONCAT('POINT (', id, ' ', id + 1,
')'))")
+ )
+ rows = geog_df.collect()
+ assert len(rows) == 10
+ for row in rows:
+ id = row["id"]
+ geog = row["geog"]
+ assert geog.geometry.wkt == f"POINT ({id} {id + 1})"
+
+ def test_serialize_geography(self):
+ wkt = "MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20
-20, -20 -10, -10 -10)))"
+ geog = Geography(wkt_loads(wkt))
+ schema = StructType().add("geog", GeographyType())
+ returned_geog = self.spark.createDataFrame([(geog,)],
schema).take(1)[0][0]
+ assert geog.geometry.equals(returned_geog.geometry)
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index af51a825f8..57fe3893e0 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -40,6 +40,7 @@ object Catalog extends AbstractCatalog {
function[ST_GeomFromText](0),
function[ST_GeometryFromText](0),
function[ST_LineFromText](),
+ function[ST_GeogFromWKT](0),
function[ST_GeomFromWKT](0),
function[ST_GeomFromEWKT](),
function[ST_GeomFromWKB](),
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala
new file mode 100644
index 0000000000..18ad4dfebd
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeographyUDT.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.UDT
+
+import org.apache.sedona.common.geometrySerde.GeometrySerializer;
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.json4s.JsonDSL._
+import org.json4s.JsonAST.JValue
+import org.apache.sedona.common.geometryObjects.Geography;
+
+class GeographyUDT extends UserDefinedType[Geography] {
+ override def sqlType: DataType = BinaryType
+
+ override def pyUDT: String = "sedona.sql.types.GeographyType"
+
+ override def userClass: Class[Geography] = classOf[Geography]
+
+ override def serialize(obj: Geography): Array[Byte] =
+ GeometrySerializer.serialize(obj.getGeometry)
+
+ override def deserialize(datum: Any): Geography = {
+ datum match {
+ case value: Array[Byte] => new
Geography(GeometrySerializer.deserialize(value))
+ }
+ }
+
+ override private[sql] def jsonValue: JValue = {
+ super.jsonValue mapField {
+ case ("class", _) => "class" -> this.getClass.getName.stripSuffix("$")
+ case other: Any => other
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case _: UserDefinedType[_] => other.isInstanceOf[GeographyUDT]
+ case _ => false
+ }
+
+ override def hashCode(): Int = userClass.hashCode()
+}
+
+case object GeographyUDT
+ extends org.apache.spark.sql.sedona_sql.UDT.GeographyUDT
+ with scala.Serializable
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
index a96d15c008..e581625556 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
@@ -20,12 +20,14 @@ package org.apache.spark.sql.sedona_sql.UDT
import org.apache.spark.sql.types.UDTRegistration
import org.locationtech.jts.geom.Geometry
+import org.apache.sedona.common.geometryObjects.Geography;
import org.locationtech.jts.index.SpatialIndex
object UdtRegistratorWrapper {
def registerAll(): Unit = {
UDTRegistration.register(classOf[Geometry].getName,
classOf[GeometryUDT].getName)
+ UDTRegistration.register(classOf[Geography].getName,
classOf[GeographyUDT].getName)
UDTRegistration.register(classOf[SpatialIndex].getName,
classOf[IndexUDT].getName)
}
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index d787ff152b..cf2a05d3d3 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -98,6 +98,20 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
}
}
+/**
+ * Return a Geography from a WKT string
+ *
+ * @param inputExpressions
+ * This function takes a geometry string and a srid. The string format must
be WKT.
+ */
+case class ST_GeogFromWKT(inputExpressions: Seq[Expression])
+ extends InferredExpression(Constructors.geogFromWKT _) {
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
+ copy(inputExpressions = newChildren)
+ }
+}
+
/**
* Return a Geometry from a OGC Extended WKT string
*
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FunctionResolver.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FunctionResolver.scala
new file mode 100644
index 0000000000..eb4c9f4f64
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FunctionResolver.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * A utility object for resolving functions based on input argument types. See
+ * [[FunctionResolver.resolveFunction]] for details.
+ */
+object FunctionResolver {
+
+ /**
+ * A utility function for selecting the function to be called from multiple
function overloads.
+ * The overall rule is:
+ *
+ * 1. If any of the input argument cannot be coerced to the parameter
type, don't select that
+ * function.
+ * 1. If there is a perfect match, return it.
+ * 1. If there is no perfect match, return the function with the fewest
coerced inputs.
+ */
+ def resolveFunction(
+ expressions: Seq[Expression],
+ functionOverloads: Seq[InferrableFunction]): InferrableFunction = {
+
+ // If there's only one overload matches the arity of the expression, we'll
simply use it.
+ // The SQL analyzer will handle the ImplicitCastInputTypes trait and do
type checking for us.
+ val functionsWithMatchingArity =
+ functionOverloads.filter(_.sparkInputTypes.length == expressions.length)
+ if (functionsWithMatchingArity.length == 1) {
+ return functionsWithMatchingArity.head
+ } else if (functionsWithMatchingArity.isEmpty) {
+ throw new IllegalArgumentException(
+ s"No overloaded function accepts ${expressions.length} arguments")
+ }
+
+ val functionWithMatchResults = functionsWithMatchingArity.map { function =>
+ val matchResult = matchFunctionToInputTypes(expressions, function)
+ (function, matchResult)
+ }
+ // If there's a perfect match, return it; otherwise, return the function
with the fewest
+ // coerced inputs.
+ val bestMatch = functionWithMatchResults.minBy { case (_, matchResult) =>
+ matchResult match {
+ case NotMatch => Int.MaxValue
+ case PerfectMatch => 0
+ case CoercedMatch(coercedInputs) => coercedInputs
+ }
+ }
+ bestMatch match {
+ case (_, NotMatch) =>
+ val candidateTypesMsg = functionsWithMatchingArity
+ .map { function =>
+ " (" + function.sparkInputTypes.mkString(", ") + ")"
+ }
+ .mkString("\n")
+ throw new IllegalArgumentException(
+ "Types of arguments does not match with function parameters. " +
+ "Candidates are: \n" + candidateTypesMsg)
+ case (function, _) =>
+ // Make sure that there's no ambiguity in the best match.
+ val ambiguousMatches = functionWithMatchResults.filter { case (_,
matchResult) =>
+ matchResult == bestMatch._2
+ }
+ if (ambiguousMatches.length == 1) {
+ function
+ } else {
+ // Detected ambiguous matches, throw exception
+ val candidateTypesMsg = ambiguousMatches
+ .map { case (function, _) =>
+ " (" + function.sparkInputTypes.mkString(", ") + ")"
+ }
+ .mkString("\n")
+ throw new IllegalArgumentException(
+ "Ambiguous function call. Candidates are: \n" + candidateTypesMsg)
+ }
+ }
+ }
+
+ private sealed trait MatchResult
+ private case object NotMatch extends MatchResult
+ private case object PerfectMatch extends MatchResult
+ private case class CoercedMatch(coercedInputs: Int) extends MatchResult
+
+ private def matchFunctionToInputTypes(
+ expressions: Seq[Expression],
+ function: InferrableFunction): MatchResult = {
+ if (expressions.length != function.sparkInputTypes.length) {
+ NotMatch
+ } else {
+ val inputMatchResults =
+ expressions.zip(function.sparkInputTypes).map { case (expr,
parameterType) =>
+ val argumentType = expr.dataType
+ if (parameterType.acceptsType(argumentType)) {
+ PerfectMatch
+ } else {
+ TypeCoercion.implicitCast(expr, parameterType) match {
+ case Some(_) => CoercedMatch(1)
+ case None => NotMatch
+ }
+ }
+ }
+ if (inputMatchResults.contains(NotMatch)) {
+ NotMatch
+ } else {
+ val numCoercedInputs = inputMatchResults.map {
+ case CoercedMatch(coercedInputs) => coercedInputs
+ case _ => 0
+ }.sum
+ if (numCoercedInputs == 0) {
+ PerfectMatch
+ } else {
+ CoercedMatch(numCoercedInputs)
+ }
+ }
+ }
+ }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 705e37758b..8fb84b3b61 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -18,6 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.expressions
+import org.apache.sedona.common.geometryObjects.Geography
import org.apache.sedona.common.{Functions, FunctionsGeoTools}
import org.apache.sedona.common.sphere.{Haversine, Spheroid}
import org.apache.sedona.common.utils.{InscribedCircle, ValidDetail}
@@ -506,7 +507,9 @@ case class ST_AsBinary(inputExpressions: Seq[Expression])
}
case class ST_AsEWKB(inputExpressions: Seq[Expression])
- extends InferredExpression(Functions.asEWKB _) {
+ extends InferredExpression(
+ (geom: Geometry) => Functions.asEWKB(geom),
+ (geog: Geography) => Functions.asEWKB(geog)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
@@ -1244,7 +1247,9 @@ case class ST_Force_2D(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_AsEWKT(inputExpressions: Seq[Expression])
- extends InferredExpression(Functions.asEWKT _) {
+ extends InferredExpression(
+ (geom: Geometry) => Functions.asEWKT(geom),
+ (geog: Geography) => Functions.asEWKT(geog)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index 935c2d5e3d..68290b6983 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, GeographyUDT}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType,
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
@@ -34,6 +34,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
+import org.apache.sedona.common.geometryObjects.Geography
/**
* Custom exception to include the input row and the original exception
message.
@@ -61,15 +62,16 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
lazy val f: InferrableFunction = fSeq match {
// If there is only one function, simply use it and let
org.apache.sedona.sql.UDF.Catalog handle default arguments.
case Seq(f) => f
- // If there are multiple overloaded functions, find the one with the same
number of arguments as the input
- // expressions. Please note that the Catalog won't be able to handle
default arguments in this case. We'll
- // move default argument handling from Catalog to this class in the future.
+ // If there are multiple overloaded functions, resolve the called function
by comparing the types of input
+ // arguments with the types of function parameters.
case _ =>
- fSeq.find(f => f.sparkInputTypes.size == inputExpressions.size) match {
- case Some(f) => f
- case None =>
+ try {
+ FunctionResolver.resolveFunction(inputExpressions, fSeq)
+ } catch {
+ case e: IllegalArgumentException =>
throw new IllegalArgumentException(
- s"No overloaded function ${getClass.getName} has
${inputExpressions.size} arguments")
+ s"Cannot resolve function ${getClass.getName} with input
arguments: " +
+ inputExpressions.map(_.dataType).mkString(", ") + s":
${e.getMessage}")
}
}
@@ -160,6 +162,10 @@ object InferrableType {
new InferrableType[Geometry] {}
implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
new InferrableType[Array[Geometry]] {}
+ implicit val geographyInstance: InferrableType[Geography] =
+ new InferrableType[Geography] {}
+ implicit val geographyArrayInstance: InferrableType[Array[Geography]] =
+ new InferrableType[Array[Geography]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
new InferrableType[java.lang.Double] {}
implicit val javaIntegerInstance: InferrableType[java.lang.Integer] =
@@ -200,6 +206,8 @@ object InferredTypes {
def buildArgumentExtractor(t: Type): Expression => InternalRow => Any = {
if (t =:= typeOf[Geometry]) { expr => input =>
expr.toGeometry(input)
+ } else if (t =:= typeOf[Geography]) { expr => input =>
+ expr.toGeography(input)
} else if (t =:= typeOf[Array[Geometry]]) { expr => input =>
expr.toGeometryArray(input)
} else if (InferredRasterExpression.isRasterType(t)) {
@@ -230,6 +238,12 @@ object InferredTypes {
} else {
null
}
+ } else if (t =:= typeOf[Geography]) { output =>
+ if (output != null) {
+ output.asInstanceOf[Geography].toGenericArrayData
+ } else {
+ null
+ }
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterSerializer
} else if (t =:= typeOf[String]) { output =>
@@ -259,6 +273,14 @@ object InferredTypes {
} else {
null
}
+ } else if (t =:= typeOf[Array[Geography]] || t =:=
typeOf[java.util.List[Geography]]) {
+ output =>
+ if (output != null) {
+
ArrayData.toArrayData(output.asInstanceOf[Array[Geography]].map(_.toGenericArrayData))
+ } else {
+ null
+ }
+
} else if (InferredRasterExpression.isRasterArrayType(t)) {
InferredRasterExpression.rasterArraySerializer
} else if (t =:= typeOf[Option[Boolean]]) { output =>
@@ -277,6 +299,10 @@ object InferredTypes {
GeometryUDT
} else if (t =:= typeOf[Array[Geometry]] || t =:=
typeOf[java.util.List[Geometry]]) {
DataTypes.createArrayType(GeometryUDT)
+ } else if (t =:= typeOf[Geography]) {
+ GeographyUDT
+ } else if (t =:= typeOf[Array[Geography]] || t =:=
typeOf[java.util.List[Geography]]) {
+ DataTypes.createArrayType(GeographyUDT)
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterUDT
} else if (InferredRasterExpression.isRasterArrayType(t)) {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index fd55cc02d8..4e97a4fc45 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}
+import org.apache.sedona.common.geometryObjects.Geography
object implicits {
@@ -60,6 +61,18 @@ object implicits {
}
}
+ def toGeography(input: InternalRow): Geography = {
+ inputExpression match {
+ case serdeAware: SerdeAware =>
+ serdeAware.evalWithoutSerialization(input).asInstanceOf[Geography]
+ case _ =>
+ inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
+ case binary: Array[Byte] => new
Geography(GeometrySerializer.deserialize(binary))
+ case _ => null
+ }
+ }
+ }
+
def toDoubleList(input: InternalRow): java.util.List[java.lang.Double] = {
inputExpression match {
case aware: SerdeAware =>
@@ -141,4 +154,9 @@ object implicits {
def isNonEmpty: Boolean = geom != null && !geom.isEmpty
}
+
+ implicit class GeographyEnhancer(geog: Geography) {
+
+ def toGenericArrayData: Array[Byte] =
GeometrySerializer.serialize(geog.getGeometry)
+ }
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
index 437ad91630..349e38b8b7 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
@@ -89,6 +89,12 @@ object st_constructors extends DataFrameAPI {
def ST_GeomFromEWKT(wkt: Column): Column =
wrapExpression[ST_GeomFromEWKT](wkt)
def ST_GeomFromEWKT(wkt: String): Column =
wrapExpression[ST_GeomFromEWKT](wkt)
+ def ST_GeogFromWKT(wkt: Column): Column =
wrapExpression[ST_GeogFromWKT](wkt, 0)
+ def ST_GeogFromWKT(wkt: String): Column =
wrapExpression[ST_GeogFromWKT](wkt, 0)
+ def ST_GeogFromWKT(wkt: Column, srid: Column): Column =
+ wrapExpression[ST_GeogFromWKT](wkt, srid)
+ def ST_GeogFromWKT(wkt: String, srid: Int): Column =
wrapExpression[ST_GeogFromWKT](wkt, srid)
+
def ST_LineFromText(wkt: Column): Column =
wrapExpression[ST_LineFromText](wkt)
def ST_LineFromText(wkt: String): Column =
wrapExpression[ST_LineFromText](wkt)
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/FunctionResolverSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/FunctionResolverSuite.scala
new file mode 100644
index 0000000000..63b30a5107
--- /dev/null
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/FunctionResolverSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.sedona_sql.expressions.{FunctionResolver,
InferrableFunction}
+import org.apache.spark.sql.types._
+import org.scalatest.funspec.AnyFunSpec
+
+class FunctionResolverSuite extends AnyFunSpec {
+ describe("FunctionResolver test") {
+ // Helper method to create test functions
+ def createTestFunction(inputTypes: Seq[DataType]): InferrableFunction = {
+ InferrableFunction(
+ sparkInputTypes = inputTypes,
+ sparkReturnType = StringType,
+ serializer = identity,
+ argExtractorBuilders = Seq.empty,
+ evaluatorBuilder = _ => _ => "")
+ }
+
+ // Helper method to create test expressions
+ def createTestExpression(dataType: DataType): Expression = {
+ Literal.create(null, dataType)
+ }
+
+ it("No function matching input arity") {
+ val functions = Seq(
+ createTestFunction(Seq(IntegerType)),
+ createTestFunction(Seq(IntegerType, StringType)))
+ val expressions = Seq(
+ createTestExpression(IntegerType),
+ createTestExpression(StringType),
+ createTestExpression(DoubleType))
+
+ assertThrows[IllegalArgumentException] {
+ FunctionResolver.resolveFunction(expressions, functions)
+ }
+ }
+
+ it("Only one function matches input arity") {
+ val functions = Seq(
+ createTestFunction(Seq(IntegerType)),
+ createTestFunction(Seq(IntegerType, StringType, DoubleType)))
+ val expressions = Seq(createTestExpression(IntegerType))
+
+ val result = FunctionResolver.resolveFunction(expressions, functions)
+ assert(result.sparkInputTypes == Seq(IntegerType))
+ }
+
+ it("Multiple functions match input arity, perfect match") {
+ val functions =
+ Seq(createTestFunction(Seq(IntegerType)),
createTestFunction(Seq(StringType)))
+ val expressions = Seq(createTestExpression(IntegerType))
+
+ val result = FunctionResolver.resolveFunction(expressions, functions)
+ assert(result.sparkInputTypes == Seq(IntegerType))
+ }
+
+ it("Multiple functions match input arity, no perfect match, no ambiguity")
{
+ val functions = Seq(
+ createTestFunction(Seq(LongType, StringType)),
+ createTestFunction(Seq(DoubleType, StringType)))
+ val expressions = Seq(createTestExpression(LongType),
createTestExpression(LongType))
+
+ val result = FunctionResolver.resolveFunction(expressions, functions)
+ // The Long arguments match LongType exactly, while the Double overload would require a lossy Long-to-Double coercion
+ assert(result.sparkInputTypes == Seq(LongType, StringType))
+ }
+
+ it("Multiple functions match input arity, ambiguity") {
+ val functions = Seq(
+ createTestFunction(Seq(LongType, StringType)),
+ createTestFunction(Seq(DoubleType, StringType)))
+ val expressions = Seq(createTestExpression(IntegerType),
createTestExpression(StringType))
+
+ assertThrows[IllegalArgumentException] {
+ FunctionResolver.resolveFunction(expressions, functions)
+ }
+ }
+ }
+}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 2dd9cdfedd..402785e27f 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -18,6 +18,7 @@
*/
package org.apache.sedona.sql
+import org.apache.sedona.common.geometryObjects.Geography
import org.apache.sedona.core.formatMapper.GeoJsonReader
import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader
import org.apache.sedona.sql.utils.Adapter
@@ -263,6 +264,14 @@ class constructorTestScala extends TestBaseScala {
assert(thrown.getMessage.contains("Unknown geometry type"))
}
+ it("Passed ST_GeogFromWKT") {
+ val wkt = "LINESTRING (1 2, 3 4, 5 6)"
+ val row = sparkSession.sql(s"SELECT ST_GeogFromWKT('$wkt') AS
geog").first()
+ val geog = row.get(0)
+ assert(geog.isInstanceOf[Geography])
+ assert(geog.asInstanceOf[Geography].getGeometry.toText == wkt)
+ }
+
it("Passed ST_LineFromText") {
val geometryDf = Seq("Linestring(1 2, 3 4)").map(wkt =>
Tuple1(wkt)).toDF("geom")
geometryDf.createOrReplaceTempView("linetable")
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index a89af2355d..9874f0519d 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -19,8 +19,9 @@
package org.apache.sedona.sql
import org.apache.commons.codec.binary.Hex
+import org.apache.sedona.common.geometryObjects.Geography
import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.{radians, col, element_at, expr, lit}
+import org.apache.spark.sql.functions.{col, element_at, expr, lit, radians}
import org.apache.spark.sql.sedona_sql.expressions.InferredExpressionException
import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
@@ -223,6 +224,21 @@ class dataFrameAPITestScala extends TestBaseScala {
assert(actualResult.getSRID == 4326)
}
+ it("passed st_geogfromwkt") {
+ val df = sparkSession.sql("SELECT 'POINT(0.0 1.0)' AS
wkt").select(ST_GeogFromWKT("wkt"))
+ val actualResult = df.take(1)(0).get(0).asInstanceOf[Geography].toString
+ val expectedResult = "POINT (0 1)"
+ assert(actualResult == expectedResult)
+ }
+
+ it("passed st_geogfromwkt with srid") {
+ val df =
+ sparkSession.sql("SELECT 'POINT(0.0 1.0)' AS
wkt").select(ST_GeogFromWKT("wkt", 4326))
+ val actualResult = df.take(1)(0).get(0).asInstanceOf[Geography]
+ assert(actualResult.toString == "POINT (0 1)")
+ assert(actualResult.getGeometry.getSRID == 4326)
+ }
+
it("passed st_geomfromewkt") {
val df = sparkSession
.sql("SELECT 'SRID=4269;POINT(0.0 1.0)' AS wkt")
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 84770157c7..665304662e 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -930,6 +930,9 @@ class functionTestScala
df = sparkSession.sql("SELECT ST_AsEWKB(point) from table")
val s = "0101000020cd0b0000000000000000f03f000000000000f03f"
assert(Hex.encodeHexString(df.first().get(0).asInstanceOf[Array[Byte]])
== s)
+ df = sparkSession.sql("SELECT ST_AsEWKB(ST_GeogFromWKT('POINT (1 1)'))")
+ val wkb = df.first().get(0).asInstanceOf[Array[Byte]]
+ assert(Hex.encodeHexString(wkb) ==
"0101000000000000000000f03f000000000000f03f")
}
it("Passed ST_AsHEXEWKB") {
@@ -950,6 +953,13 @@ class functionTestScala
assert(Hex.encodeHexString(df.first().get(0).asInstanceOf[Array[Byte]])
== s)
}
+ it("Passed ST_AsEWKT") {
+ val wkt = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"
+ val df = sparkSession.sql(s"SELECT ST_AsEWKT(ST_GeogFromWKT('$wkt'))")
+ val row = df.first()
+ assert(row.getString(0) == wkt)
+ }
+
it("Passed ST_Simplify") {
val baseDf = sparkSession.sql("SELECT ST_Buffer(ST_GeomFromWKT('POINT (0
2)'), 10) AS geom")
val actualPoints = baseDf.selectExpr("ST_NPoints(ST_Simplify(geom,
1))").first().get(0)
@@ -2682,7 +2692,9 @@ class functionTestScala
assert(functionDf.first().get(0) == null)
functionDf = sparkSession.sql("select ST_AsBinary(null)")
assert(functionDf.first().get(0) == null)
- functionDf = sparkSession.sql("select ST_AsEWKB(null)")
+ functionDf = sparkSession.sql("select ST_AsEWKB(ST_GeomFromWKT(null))")
+ assert(functionDf.first().get(0) == null)
+ functionDf = sparkSession.sql("select ST_AsEWKB(ST_GeogFromWKT(null))")
assert(functionDf.first().get(0) == null)
functionDf = sparkSession.sql("select ST_SRID(null)")
assert(functionDf.first().get(0) == null)
@@ -2754,7 +2766,9 @@ class functionTestScala
assert(functionDf.first().get(0) == null)
functionDf = sparkSession.sql("select ST_Reverse(null)")
assert(functionDf.first().get(0) == null)
- functionDf = sparkSession.sql("select ST_AsEWKT(null)")
+ functionDf = sparkSession.sql("select ST_AsEWKT(ST_GeomFromWKT(null))")
+ assert(functionDf.first().get(0) == null)
+ functionDf = sparkSession.sql("select ST_AsEWKT(ST_GeogFromWKT(null))")
assert(functionDf.first().get(0) == null)
functionDf = sparkSession.sql("select ST_Force_2D(null)")
assert(functionDf.first().get(0) == null)