This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch aggregate-function-rename in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 7896d455b6a7712f71c42a3845399629effd159b Author: Jia Yu <[email protected]> AuthorDate: Wed Dec 17 00:15:20 2025 -0800 WIP --- docs/api/flink/Aggregator.md | 27 ++++++--- .../api/snowflake/vector-data/AggregateFunction.md | 27 ++++++--- docs/api/sql/AggregateFunction.md | 29 +++++---- .../main/java/org/apache/sedona/flink/Catalog.java | 4 ++ .../sedona/flink/expressions/Aggregators.java | 10 ++++ .../org/apache/sedona/flink/AggregatorTest.java | 34 +++++++++++ python/sedona/spark/sql/st_aggregates.py | 43 ++++++++++++++ python/tests/sql/test_aggregate_functions.py | 69 ++++++++++++++++++++++ python/tests/sql/test_dataframe_api.py | 22 +++++++ .../snowflake/snowsql/udtfs/ST_Envelope_Agg.java | 26 ++++++++ .../snowsql/udtfs/ST_Intersection_Agg.java | 26 ++++++++ .../snowflake/snowsql/udtfs/ST_Union_Agg.java | 26 ++++++++ .../apache/sedona/sql/UDF/AbstractCatalog.scala | 30 +++++++--- .../sql/sedona_sql/expressions/st_aggregates.scala | 13 ++++ .../sedona/sql/aggregateFunctionTestScala.scala | 41 +++++++++++++ .../apache/sedona/sql/dataFrameAPITestScala.scala | 28 +++++++++ 16 files changed, 420 insertions(+), 35 deletions(-) diff --git a/docs/api/flink/Aggregator.md b/docs/api/flink/Aggregator.md index 87b2a0f8be..94f252f617 100644 --- a/docs/api/flink/Aggregator.md +++ b/docs/api/flink/Aggregator.md @@ -17,18 +17,21 @@ under the License. --> -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A: geometryColumn)` +Format: `ST_Envelope_Agg (A: geometryColumn)` Since: `v1.3.0` +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. 
+ Example: ```sql -SELECT ST_Envelope_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Envelope_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -37,18 +40,21 @@ Output: POLYGON ((1.1 101.1, 1.1 120.1, 20.1 120.1, 20.1 101.1, 1.1 101.1)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A: geometryColumn)` +Format: `ST_Intersection_Agg (A: geometryColumn)` Since: `v1.5.0` +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. + Example: ```sql -SELECT ST_Intersection_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Intersection_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -57,18 +63,21 @@ Output: MULTIPOINT ((1.1 101.1), (2.1 102.1), (3.1 103.1), (4.1 104.1), (5.1 105.1), (6.1 106.1), (7.1 107.1), (8.1 108.1), (9.1 109.1), (10.1 110.1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A. All inputs must be polygons. -Format: `ST_Union_Aggr (A: geometryColumn)` +Format: `ST_Union_Agg (A: geometryColumn)` Since: `v1.3.0` +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. 
+ Example: ```sql -SELECT ST_Union_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Union_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: diff --git a/docs/api/snowflake/vector-data/AggregateFunction.md b/docs/api/snowflake/vector-data/AggregateFunction.md index 7669819d99..b8ba8ee4e2 100644 --- a/docs/api/snowflake/vector-data/AggregateFunction.md +++ b/docs/api/snowflake/vector-data/AggregateFunction.md @@ -20,11 +20,14 @@ !!!note Please always keep the schema name `SEDONA` (e.g., `SEDONA.ST_GeomFromWKT`) when you use Sedona functions to avoid conflicting with Snowflake's built-in functions. -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A:geometryColumn)` +Format: `ST_Envelope_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. SQL example: @@ -36,7 +39,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(envelope) FROM src_tbl, - TABLE(sedona.ST_Envelope_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Envelope_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: @@ -45,11 +48,14 @@ Output: POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A:geometryColumn)` +Format: `ST_Intersection_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. 
SQL example: @@ -61,7 +67,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(intersected) FROM src_tbl, - TABLE(sedona.ST_Intersection_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Intersection_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: @@ -70,11 +76,14 @@ Output: POLYGON ((0.5 1, 1 1, 1 0.5, 0.5 0.5, 0.5 1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A -Format: `ST_Union_Aggr (A:geometryColumn)` +Format: `ST_Union_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. SQL example: @@ -86,7 +95,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(unioned) FROM src_tbl, - TABLE(sedona.ST_Union_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Union_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: diff --git a/docs/api/sql/AggregateFunction.md b/docs/api/sql/AggregateFunction.md index 8918c6d428..4d0165f2b3 100644 --- a/docs/api/sql/AggregateFunction.md +++ b/docs/api/sql/AggregateFunction.md @@ -19,7 +19,7 @@ ## ST_Collect_Agg -Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Aggr`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry. +Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Agg`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry. 
Format: `ST_Collect_Agg (A: geometryColumn)` @@ -49,18 +49,21 @@ SQL Example with GROUP BY SELECT category, ST_Collect_Agg(geom) FROM geometries GROUP BY category ``` -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A: geometryColumn)` +Format: `ST_Envelope_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. + SQL Example ```sql -SELECT ST_Envelope_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Envelope_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -69,18 +72,21 @@ Output: POLYGON ((1.1 101.1, 1.1 120.1, 20.1 120.1, 20.1 101.1, 1.1 101.1)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A: geometryColumn)` +Format: `ST_Intersection_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. 
+ SQL Example ```sql -SELECT ST_Intersection_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Intersection_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -89,18 +95,21 @@ Output: MULTIPOINT ((1.1 101.1), (2.1 102.1), (3.1 103.1), (4.1 104.1), (5.1 105.1), (6.1 106.1), (7.1 107.1), (8.1 108.1), (9.1 109.1), (10.1 110.1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A -Format: `ST_Union_Aggr (A: geometryColumn)` +Format: `ST_Union_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. + SQL Example ```sql -SELECT ST_Union_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Union_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index dcb593e9c3..0ee24de799 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -27,6 +27,10 @@ public class Catalog { new Aggregators.ST_Envelope_Aggr(), new Aggregators.ST_Intersection_Aggr(), new Aggregators.ST_Union_Aggr(), + // Aliases for *_Aggr functions with *_Agg suffix + new Aggregators.ST_Envelope_Agg(), + new Aggregators.ST_Intersection_Agg(), + new Aggregators.ST_Union_Agg(), new Constructors.ST_Point(), new Constructors.ST_PointZ(), new Constructors.ST_PointM(), diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java 
b/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java index de3d2f525d..73af1729c1 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java @@ -199,4 +199,14 @@ public class Aggregators { acc.geom = null; } } + + // Aliases for *_Aggr functions with *_Agg suffix + @DataTypeHint(value = "RAW", bridgedTo = Geometry.class) + public static class ST_Envelope_Agg extends ST_Envelope_Aggr {} + + @DataTypeHint(value = "RAW", bridgedTo = Geometry.class) + public static class ST_Intersection_Agg extends ST_Intersection_Aggr {} + + @DataTypeHint(value = "RAW", bridgedTo = Geometry.class) + public static class ST_Union_Agg extends ST_Union_Aggr {} } diff --git a/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java b/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java index 258927622a..0220ff434d 100644 --- a/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java @@ -93,4 +93,38 @@ public class AggregatorTest extends TestBase { Row last = last(result); assertEquals(1001, ((Polygon) last.getField(0)).getArea(), 0); } + + // Test aliases for *_Aggr functions with *_Agg suffix + @Test + public void testEnvelop_Agg_Alias() { + Table pointTable = createPointTable(testDataSize); + Table result = pointTable.select(call("ST_Envelope_Agg", $(pointColNames[0]))); + Row last = last(result); + assertEquals( + String.format( + "POLYGON ((0 0, 0 %s, %s %s, %s 0, 0 0))", + testDataSize - 1, testDataSize - 1, testDataSize - 1, testDataSize - 1), + last.getField(0).toString()); + } + + @Test + public void testIntersection_Agg_Alias() { + Table polygonTable = createPolygonOverlappingTable(testDataSize); + Table result = polygonTable.select(call("ST_Intersection_Agg", $(polygonColNames[0]))); + Row last = last(result); + assertEquals("LINESTRING EMPTY", 
last.getField(0).toString()); + + polygonTable = createPolygonOverlappingTable(3); + result = polygonTable.select(call("ST_Intersection_Agg", $(polygonColNames[0]))); + last = last(result); + assertEquals("LINESTRING (1 1, 1 0)", last.getField(0).toString()); + } + + @Test + public void testUnion_Agg_Alias() { + Table polygonTable = createPolygonOverlappingTable(testDataSize); + Table result = polygonTable.select(call("ST_Union_Agg", $(polygonColNames[0]))); + Row last = last(result); + assertEquals(1001, ((Polygon) last.getField(0)).getArea(), 0); + } } diff --git a/python/sedona/spark/sql/st_aggregates.py b/python/sedona/spark/sql/st_aggregates.py index ec20e64307..c85a3983ba 100644 --- a/python/sedona/spark/sql/st_aggregates.py +++ b/python/sedona/spark/sql/st_aggregates.py @@ -81,6 +81,49 @@ def ST_Collect_Agg(geometry: ColumnOrName) -> Column: return _call_aggregate_function("ST_Collect_Agg", geometry) +# Aliases for *_Aggr functions with *_Agg suffix +@validate_argument_types +def ST_Envelope_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate envelope of a geometry column. + + This is an alias for ST_Envelope_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate envelope of the geometry column. + :rtype: Column + """ + return ST_Envelope_Aggr(geometry) + + +@validate_argument_types +def ST_Intersection_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate intersection of a geometry column. + + This is an alias for ST_Intersection_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate intersection of the geometry column. + :rtype: Column + """ + return ST_Intersection_Aggr(geometry) + + +@validate_argument_types +def ST_Union_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate union of a geometry column. 
+ + This is an alias for ST_Union_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate union of the geometry column. + :rtype: Column + """ + return ST_Union_Aggr(geometry) + + # Automatically populate __all__ __all__ = [ name diff --git a/python/tests/sql/test_aggregate_functions.py b/python/tests/sql/test_aggregate_functions.py index 4df0e3e230..93c7b1671c 100644 --- a/python/tests/sql/test_aggregate_functions.py +++ b/python/tests/sql/test_aggregate_functions.py @@ -145,3 +145,72 @@ class TestConstructors(TestBase): assert len(result.geoms) == 2 # Area should be 2 because it doesn't merge overlapping areas assert result.area == 2.0 + + # Test aliases for *_Aggr functions with *_Agg suffix + def test_st_envelope_agg_alias(self): + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) + + point_csv_df.createOrReplaceTempView("pointtable_alias") + point_df = self.spark.sql( + "select ST_Point(cast(pointtable_alias._c0 as Decimal(24,20)), cast(pointtable_alias._c1 as Decimal(24,20))) as arealandmark from pointtable_alias" + ) + point_df.createOrReplaceTempView("pointdf_alias") + boundary = self.spark.sql( + "select ST_Envelope_Agg(pointdf_alias.arealandmark) from pointdf_alias" + ) + + coordinates = [ + (1.1, 101.1), + (1.1, 1100.1), + (1000.1, 1100.1), + (1000.1, 101.1), + (1.1, 101.1), + ] + + polygon = Polygon(coordinates) + assert boundary.take(1)[0][0].equals(polygon) + + def test_st_intersection_agg_alias(self): + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(union_polygon_input_location) + ) + + polygon_wkt_df.createOrReplaceTempView("polygontable_alias") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable_alias._c0) as countyshape from polygontable_alias" + ) + 
polygon_df.createOrReplaceTempView("polygondf_alias") + intersection = self.spark.sql( + "select ST_Intersection_Agg(polygondf_alias.countyshape) from polygondf_alias" + ) + + result = intersection.take(1)[0][0] + assert result.area > 0 + + def test_st_union_agg_alias(self): + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(union_polygon_input_location) + ) + + polygon_wkt_df.createOrReplaceTempView("polygontable_union_alias") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable_union_alias._c0) as countyshape from polygontable_union_alias" + ) + polygon_df.createOrReplaceTempView("polygondf_union_alias") + union = self.spark.sql( + "select ST_Union_Agg(polygondf_union_alias.countyshape) from polygondf_union_alias" + ) + + result = union.take(1)[0][0] + assert result.area > 0 diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 3ea7ab2076..9629a7ca55 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -1237,6 +1237,28 @@ test_configurations = [ "", "MULTIPOINT ((0 0), (1 1))", ), + # Test aliases for *_Aggr functions with *_Agg suffix + ( + sta.ST_Envelope_Agg, + ("geom",), + "exploded_points", + "", + "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", + ), + ( + sta.ST_Intersection_Agg, + ("geom",), + "exploded_polys", + "", + "LINESTRING (1 0, 1 1)", + ), + ( + sta.ST_Union_Agg, + ("geom",), + "exploded_polys", + "", + "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))", + ), ] wrong_type_configurations = [ diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java new file mode 100644 index 0000000000..86af166d9b --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; + +@UDTFAnnotations.TabularFunc( + name = "ST_Envelope_Agg", + argNames = {"geom"}) +public class ST_Envelope_Agg extends ST_Envelope_Aggr {} diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java new file mode 100644 index 0000000000..458464a6bb --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; + +@UDTFAnnotations.TabularFunc( + name = "ST_Intersection_Agg", + argNames = {"geom"}) +public class ST_Intersection_Agg extends ST_Intersection_Aggr {} diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java new file mode 100644 index 0000000000..4b5825c669 --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; + +@UDTFAnnotations.TabularFunc( + name = "ST_Union_Agg", + argNames = {"geom"}) +public class ST_Union_Agg extends ST_Union_Aggr {} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala index 6e6d13e023..fc15570d16 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal} import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.sedona_sql.expressions.{ST_Envelope_Aggr, ST_Intersection_Aggr, ST_Union_Aggr} import org.locationtech.jts.geom.Geometry import scala.reflect.ClassTag @@ -93,14 +94,25 @@ abstract class AbstractCatalog { functionBuilder) } aggregateExpressions.foreach { f => - sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)) - FunctionRegistry.builtin.registerFunction( - FunctionIdentifier(f.getClass.getSimpleName), - new ExpressionInfo(f.getClass.getCanonicalName, null, f.getClass.getSimpleName), - (_: Seq[Expression]) => - throw new UnsupportedOperationException( - s"Aggregate function ${f.getClass.getSimpleName} cannot be used as a regular function")) + registerAggregateFunction(sparkSession, f.getClass.getSimpleName, f) } + // Register aliases for *_Aggr functions with *_Agg suffix + registerAggregateFunction(sparkSession, "ST_Envelope_Agg", new ST_Envelope_Aggr) + registerAggregateFunction(sparkSession, "ST_Intersection_Agg", new ST_Intersection_Aggr) + registerAggregateFunction(sparkSession, "ST_Union_Agg",
new ST_Union_Aggr()) + } + + private def registerAggregateFunction( + sparkSession: SparkSession, + functionName: String, + aggregator: Aggregator[Geometry, _, _]): Unit = { + sparkSession.udf.register(functionName, functions.udaf(aggregator)) + FunctionRegistry.builtin.registerFunction( + FunctionIdentifier(functionName), + new ExpressionInfo(aggregator.getClass.getCanonicalName, null, functionName), + (_: Seq[Expression]) => + throw new UnsupportedOperationException( + s"Aggregate function $functionName cannot be used as a regular function")) } def dropAll(sparkSession: SparkSession): Unit = { @@ -110,5 +122,9 @@ abstract class AbstractCatalog { aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction( FunctionIdentifier(f.getClass.getSimpleName))) + // Drop aliases for *_Aggr functions + Seq("ST_Envelope_Agg", "ST_Intersection_Agg", "ST_Union_Agg").foreach { aliasName => + sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(aliasName)) + } } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala index fe0dc8d714..2befcee1ad 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala @@ -62,4 +62,17 @@ object st_aggregates { val aggrFunc = udaf(new ST_Collect_Agg) aggrFunc(col(geometry)) } + + // Aliases for *_Aggr functions with *_Agg suffix + def ST_Envelope_Agg(geometry: Column): Column = ST_Envelope_Aggr(geometry) + + def ST_Envelope_Agg(geometry: String): Column = ST_Envelope_Aggr(geometry) + + def ST_Intersection_Agg(geometry: Column): Column = ST_Intersection_Aggr(geometry) + + def ST_Intersection_Agg(geometry: String): Column = ST_Intersection_Aggr(geometry) + + def ST_Union_Agg(geometry: Column): Column = 
ST_Union_Aggr(geometry) + + def ST_Union_Agg(geometry: String): Column = ST_Union_Aggr(geometry) } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala index 911769e2ac..1c429ca618 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala @@ -245,6 +245,47 @@ class aggregateFunctionTestScala extends TestBaseScala { // Should only have 2 points (nulls are skipped) assert(result.getNumGeometries == 2) } + + // Test aliases for *_Aggr functions with *_Agg suffix + it("Passed ST_Envelope_Agg alias") { + var pointCsvDF = sparkSession.read + .format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csvPointInputLocation) + pointCsvDF.createOrReplaceTempView("pointtable_alias") + var pointDf = sparkSession.sql( + "select ST_Point(cast(pointtable_alias._c0 as Decimal(24,20)), cast(pointtable_alias._c1 as Decimal(24,20))) as arealandmark from pointtable_alias") + pointDf.createOrReplaceTempView("pointdf_alias") + var boundary = + sparkSession.sql("select ST_Envelope_Agg(pointdf_alias.arealandmark) from pointdf_alias") + val coordinates: Array[Coordinate] = new Array[Coordinate](5) + coordinates(0) = new Coordinate(1.1, 101.1) + coordinates(1) = new Coordinate(1.1, 1100.1) + coordinates(2) = new Coordinate(1000.1, 1100.1) + coordinates(3) = new Coordinate(1000.1, 101.1) + coordinates(4) = coordinates(0) + val geometryFactory = new GeometryFactory() + geometryFactory.createPolygon(coordinates) + assert(boundary.take(1)(0).get(0) == geometryFactory.createPolygon(coordinates)) + } + + it("Passed ST_Intersection_Agg alias") { + val polygonDf = createPolygonDataFrame(100) + polygonDf.createOrReplaceTempView("polygondf_alias") + val intersectionDF = + sparkSession.sql("SELECT ST_Intersection_Agg(geom) FROM 
polygondf_alias") + val result = intersectionDF.take(1)(0).get(0).asInstanceOf[Polygon] + assert(result.getArea > 0) + } + + it("Passed ST_Union_Agg alias") { + val polygonDf = createPolygonDataFrame(100) + polygonDf.createOrReplaceTempView("polygondf_union_alias") + val unionDF = sparkSession.sql("SELECT ST_Union_Agg(geom) FROM polygondf_union_alias") + val result = unionDF.take(1)(0).get(0).asInstanceOf[Geometry] + assert(result.getArea > 0) + } } def generateRandomPolygon(index: Int): String = { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 8b8c8ca20c..2fb0d5b5c5 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -1832,6 +1832,34 @@ class dataFrameAPITestScala extends TestBaseScala { assert(actualResult.getNumGeometries == 3) } + // Test aliases for *_Aggr functions with *_Agg suffix + it("Passed ST_Envelope_Agg alias") { + val baseDf = + sparkSession.sql("SELECT explode(array(ST_Point(0.0, 0.0), ST_Point(1.0, 1.0))) AS geom") + val df = baseDf.select(ST_Envelope_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))" + assert(actualResult == expectedResult) + } + + it("Passed ST_Union_Agg alias") { + val baseDf = sparkSession.sql( + "SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom") + val df = baseDf.select(ST_Union_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))" + assert(actualResult == expectedResult) + } + + it("Passed ST_Intersection_Agg alias") { + val baseDf = sparkSession.sql( + "SELECT 
explode(array(ST_GeomFromWKT('POLYGON ((0 0, 2 0, 2 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 3 0, 3 1, 1 1, 1 0))'))) AS geom") + val df = baseDf.select(ST_Intersection_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((2 0, 1 0, 1 1, 2 1, 2 0))" + assert(actualResult == expectedResult) + } + it("Passed ST_LineFromMultiPoint") { val baseDf = sparkSession.sql( "SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30 10))') AS multipoint")
