This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
     new 4964c1b902 [GH-2072] Fix Structured Adapter Issue with Spatial Queries (#2073)
4964c1b902 is described below
commit 4964c1b902e60b1664fb739ece4e6e94875bfd87
Author: Feng Zhang <[email protected]>
AuthorDate: Tue Jul 8 18:30:24 2025 -0700
[GH-2072] Fix Structured Adapter Issue with Spatial Queries (#2073)
* [GH-2072] Fix Structured Adapter Issue with Spatial Queries
* add more tests
---
python/tests/sql/test_structured_adapter.py | 186 +++++++++++++++++++++
.../sedona/python/wrapper/utils/implicits.scala | 10 +-
2 files changed, 195 insertions(+), 1 deletion(-)
diff --git a/python/tests/sql/test_structured_adapter.py b/python/tests/sql/test_structured_adapter.py
index 960dd45415..640540ca34 100644
--- a/python/tests/sql/test_structured_adapter.py
+++ b/python/tests/sql/test_structured_adapter.py
@@ -18,6 +18,11 @@ import glob
import tempfile
from pyspark.sql import DataFrame
+from pyspark.sql.functions import expr
+from shapely.geometry.point import Point
+
+from sedona.spark.core.enums import IndexType
+from sedona.spark.core.spatialOperator import RangeQuery
from sedona.spark.core.SpatialRDD import CircleRDD
from sedona.spark.core.enums import GridType
@@ -80,3 +85,184 @@ class TestStructuredAdapter(TestBase):
            out = td + "/out"
            partitioned_df.write.format("geoparquet").save(out)
            assert len(glob.glob(out + "/*.parquet")) == n_spatial_partitions
+
+    def test_build_index_and_range_query_with_polygons(self):
+        # Create a spatial DataFrame with polygons
+        polygons_data = [
+            (1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"),
+            (2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"),
+            (3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
+            (4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"),
+            (5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
+        ]
+
+        df = self.spark.createDataFrame(polygons_data, ["id", "wkt"])
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Convert to SpatialRDD
+        spatial_rdd = StructuredAdapter.toSpatialRdd(spatial_df, "geometry")
+
+        # Build index on the spatial RDD
+        spatial_rdd.buildIndex(IndexType.RTREE, False)
+
+        query_point = Point(2.2, 2.2)
+
+        # Perform range query
+        query_result = RangeQuery.SpatialRangeQuery(
+            spatial_rdd, query_point, True, True
+        )
+
+        # Assertions
+        result_count = query_result.count()
+
+        assert result_count > 0, f"Expected at least one result, got {result_count}"
+
+    def test_build_index_and_range_query_with_points(self):
+        # Create a spatial DataFrame with points
+        points_data = [
+            (1, "POINT(0 0)"),
+            (2, "POINT(1 1)"),
+            (3, "POINT(2 2)"),
+            (4, "POINT(3 3)"),
+            (5, "POINT(4 4)"),
+        ]
+
+        df = self.spark.createDataFrame(points_data, ["id", "wkt"])
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Convert to SpatialRDD
+        spatial_rdd = StructuredAdapter.toSpatialRdd(spatial_df, "geometry")
+
+        # Build index on the spatial RDD
+        spatial_rdd.buildIndex(IndexType.RTREE, False)
+
+        query_window = Point(2.0, 2.0).buffer(1.0)
+
+        # Perform range query
+        query_result = RangeQuery.SpatialRangeQuery(
+            spatial_rdd, query_window, True, True
+        )
+
+        # Assertions
+        result_count = query_result.count()
+        assert result_count > 0, f"Expected at least one result, got {result_count}"
+
+    def test_build_index_and_range_query_with_linestrings(self):
+        # Create a spatial DataFrame with linestrings
+        linestrings_data = [
+            (1, "LINESTRING(0 0, 1 1)"),
+            (2, "LINESTRING(1 1, 2 2)"),
+            (3, "LINESTRING(2 2, 3 3)"),
+            (4, "LINESTRING(3 3, 4 4)"),
+            (5, "LINESTRING(4 4, 5 5)"),
+        ]
+
+        df = self.spark.createDataFrame(linestrings_data, ["id", "wkt"])
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Convert to SpatialRDD
+        spatial_rdd = StructuredAdapter.toSpatialRdd(spatial_df, "geometry")
+
+        # Build index on the spatial RDD
+        spatial_rdd.buildIndex(IndexType.RTREE, False)
+
+        query_window = Point(2.0, 2.0).buffer(0.5)
+
+        # Perform range query
+        query_result = RangeQuery.SpatialRangeQuery(
+            spatial_rdd, query_window, True, True
+        )
+
+        # Assertions
+        result_count = query_result.count()
+        assert result_count > 0, f"Expected at least one result, got {result_count}"
+
+    def test_build_index_and_range_query_with_mixed_geometries(self):
+        # Create a spatial DataFrame with mixed geometry types
+        mixed_data = [
+            (1, "POINT(0 0)"),
+            (2, "LINESTRING(1 1, 2 2)"),
+            (3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
+            (4, "MULTIPOINT((3 3), (3.1 3.1))"),
+            (
+                5,
+                "MULTIPOLYGON(((4 4, 5 4, 5 5, 4 5, 4 4)), ((4.1 4.1, 4.2 4.1, 4.2 4.2, 4.1 4.2, 4.1 4.1)))",
+            ),
+        ]
+
+        df = self.spark.createDataFrame(mixed_data, ["id", "wkt"])
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Convert to SpatialRDD
+        spatial_rdd = StructuredAdapter.toSpatialRdd(spatial_df, "geometry")
+
+        # Build index on the spatial RDD
+        spatial_rdd.buildIndex(IndexType.RTREE, False)
+
+        query_window = Point(3.0, 3.0).buffer(1.0)
+
+        # Perform range query
+        query_result = RangeQuery.SpatialRangeQuery(
+            spatial_rdd, query_window, True, True
+        )
+
+        # Assertions
+        result_count = query_result.count()
+        assert result_count > 0, f"Expected at least one result, got {result_count}"
+
+    def test_toDf_preserves_columns_with_proper_types(self):
+        # Create a spatial DataFrame with various columns and types
+        data = [
+            (1, "POINT(0 0)", "alpha", 10.5, True),
+            (2, "POINT(1 1)", "beta", 20.7, False),
+            (3, "POINT(2 2)", "gamma", 30.9, True),
+        ]
+
+        schema = ["id", "wkt", "name", "value", "flag"]
+        df = self.spark.createDataFrame(data, schema)
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Store original column names and types
+        original_cols = spatial_df.columns
+        original_dtypes = {f.name: f.dataType for f in spatial_df.schema.fields}
+
+        # Convert to SpatialRDD and back to DataFrame
+        spatial_rdd = StructuredAdapter.toSpatialRdd(spatial_df, "geometry")
+        result_df = StructuredAdapter.toDf(spatial_rdd, self.spark)
+
+        # Verify all columns are preserved
+        assert len(result_df.columns) == len(original_cols)
+        for col in original_cols:
+            assert col in result_df.columns
+
+        # Verify data types are preserved
+        result_dtypes = {f.name: f.dataType for f in result_df.schema.fields}
+        for col, dtype in original_dtypes.items():
+            assert col in result_dtypes
+            assert str(result_dtypes[col]) == str(
+                dtype
+            ), f"Type mismatch for {col}: expected {dtype}, got {result_dtypes[col]}"
+
+        # Verify values are preserved
+        for i in range(1, 4):
+            original_row = spatial_df.filter(spatial_df.id == i).collect()[0]
+            result_row = result_df.filter(result_df.id == i).collect()[0]
+
+            # Compare values for each column
+            assert result_row["id"] == original_row["id"]
+            assert result_row["name"] == original_row["name"]
+            assert abs(result_row["value"] - original_row["value"]) < 0.001
+            assert result_row["flag"] == original_row["flag"]
+
+            # Verify geometry data is preserved (using WKT representation)
+            orig_wkt = (
+                spatial_df.filter(spatial_df.id == i)
+                .select(expr("ST_AsText(geometry)"))
+                .collect()[0][0]
+            )
+            result_wkt = (
+                result_df.filter(result_df.id == i)
+                .select(expr("ST_AsText(geometry)"))
+                .collect()[0][0]
+            )
+            assert orig_wkt == result_wkt
diff --git a/spark/common/src/main/scala/org/apache/sedona/python/wrapper/utils/implicits.scala b/spark/common/src/main/scala/org/apache/sedona/python/wrapper/utils/implicits.scala
index 204c57b963..60fb9d46e2 100644
--- a/spark/common/src/main/scala/org/apache/sedona/python/wrapper/utils/implicits.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/python/wrapper/utils/implicits.scala
@@ -18,6 +18,7 @@
*/
package org.apache.sedona.python.wrapper.utils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.locationtech.jts.geom.Geometry
import java.nio.charset.StandardCharsets
@@ -50,10 +51,17 @@ object implicits {
    def userDataToUtf8ByteArray: Array[Byte] = {
      geometry.getUserData match {
+        // Case when user data is null: return an empty UTF-8 byte array
        case null => EMPTY_STRING.getBytes(StandardCharsets.UTF_8)
+        // Case when user data is a String: convert the string to a UTF-8 byte array
        case data: String => data.getBytes(StandardCharsets.UTF_8)
+        // Case when user data is already an Array[Byte]: return as is
+        case data: Array[Byte] => data
+        // Case when user data is an UnsafeRow: use its getBytes method
+        case data: UnsafeRow => data.getBytes
+        // Case for any other type: convert to string, then to a UTF-8 byte array
+        case data => data.toString.getBytes(StandardCharsets.UTF_8)
      }
    }
  }
-
}
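
For readers who want the end state without diff markers, the patched helper reads roughly as below. This is a self-contained sketch rather than the exact file contents: in implicits.scala the method sits inside the implicit Geometry wrapper and uses the object's EMPTY_STRING constant, whereas here it is written as a standalone method, and the object name UserDataSerializationSketch is made up for illustration.

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.locationtech.jts.geom.Geometry

object UserDataSerializationSketch {
  // Serialize whatever user data is attached to a geometry into a UTF-8 byte
  // array, instead of assuming the user data is always a String.
  def userDataToUtf8ByteArray(geometry: Geometry): Array[Byte] =
    geometry.getUserData match {
      case null              => "".getBytes(StandardCharsets.UTF_8)   // no user data attached
      case data: String      => data.getBytes(StandardCharsets.UTF_8) // plain string payload
      case data: Array[Byte] => data                                  // already serialized bytes
      case data: UnsafeRow   => data.getBytes                         // row-backed user data
      case data              => data.toString.getBytes(StandardCharsets.UTF_8) // conservative fallback
    }
}

Before this change the match covered only the null and String cases, so any other user data type (for example the row payload that the structured adapter appears to attach to each geometry) raised a scala.MatchError when spatial query results were serialized back to Python; the added Array[Byte], UnsafeRow, and catch-all cases make that path total.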