This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new ec7d9a64f3 [GH-2545] Add ST_Collect_Agg aggregate function (#2546)
ec7d9a64f3 is described below
commit ec7d9a64f35a5ab93bf5c5d0b0969144d1dbd5b5
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Dec 4 23:03:03 2025 -0800
[GH-2545] Add ST_Collect_Agg aggregate function (#2546)
---
docs/api/sql/AggregateFunction.md | 32 ++++++
python/sedona/spark/sql/st_aggregates.py | 16 +++
python/tests/sql/test_aggregate_functions.py | 74 +++++++++++++
python/tests/sql/test_dataframe_api.py | 7 ++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 2 +-
.../expressions/AggregateFunctions.scala | 40 +++++++
.../sql/sedona_sql/expressions/st_aggregates.scala | 10 ++
.../sedona/sql/aggregateFunctionTestScala.scala | 117 +++++++++++++++++++++
.../apache/sedona/sql/dataFrameAPITestScala.scala | 9 ++
9 files changed, 306 insertions(+), 1 deletion(-)
diff --git a/docs/api/sql/AggregateFunction.md
b/docs/api/sql/AggregateFunction.md
index d3e4f0a3ab..8918c6d428 100644
--- a/docs/api/sql/AggregateFunction.md
+++ b/docs/api/sql/AggregateFunction.md
@@ -17,6 +17,38 @@
under the License.
-->
+## ST_Collect_Agg
+
+Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Aggr`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry.
+
+Format: `ST_Collect_Agg (A: geometryColumn)`
+
+Since: `v1.8.1`
+
+SQL Example
+
+```sql
+SELECT ST_Collect_Agg(geom) FROM (
+ SELECT ST_GeomFromWKT('POINT(1 2)') AS geom
+ UNION ALL
+ SELECT ST_GeomFromWKT('POINT(3 4)') AS geom
+ UNION ALL
+ SELECT ST_GeomFromWKT('POINT(5 6)') AS geom
+)
+```
+
+Output:
+
+```
+MULTIPOINT ((1 2), (3 4), (5 6))
+```
+
+SQL Example with GROUP BY
+
+```sql
+SELECT category, ST_Collect_Agg(geom) FROM geometries GROUP BY category
+```
+
## ST_Envelope_Aggr
Introduction: Return the entire envelope boundary of all geometries in A
diff --git a/python/sedona/spark/sql/st_aggregates.py
b/python/sedona/spark/sql/st_aggregates.py
index d9117e94f1..ec20e64307 100644
--- a/python/sedona/spark/sql/st_aggregates.py
+++ b/python/sedona/spark/sql/st_aggregates.py
@@ -65,6 +65,22 @@ def ST_Union_Aggr(geometry: ColumnOrName) -> Column:
return _call_aggregate_function("ST_Union_Aggr", geometry)
+@validate_argument_types
+def ST_Collect_Agg(geometry: ColumnOrName) -> Column:
+ """Aggregate Function: Collect all geometries into a multi-geometry.
+
+ Unlike ST_Union_Aggr, this function does not dissolve boundaries between
geometries.
+ It simply collects all geometries into a MultiPoint, MultiLineString,
MultiPolygon,
+ or GeometryCollection based on the input geometry types.
+
+ :param geometry: Geometry column to aggregate.
+ :type geometry: ColumnOrName
+ :return: Multi-geometry representing the collection of all geometries in
the column.
+ :rtype: Column
+ """
+ return _call_aggregate_function("ST_Collect_Agg", geometry)
+
+
# Automatically populate __all__
__all__ = [
name
diff --git a/python/tests/sql/test_aggregate_functions.py
b/python/tests/sql/test_aggregate_functions.py
index 335ca5687f..4df0e3e230 100644
--- a/python/tests/sql/test_aggregate_functions.py
+++ b/python/tests/sql/test_aggregate_functions.py
@@ -71,3 +71,77 @@ class TestConstructors(TestBase):
)
assert union.take(1)[0][0].area == 10100
+
+ def test_st_collect_aggr_points(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POINT(1 2)'),
+ ST_GeomFromWKT('POINT(3 4)'),
+ ST_GeomFromWKT('POINT(5 6)')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("points_table")
+
+ result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM
points_table").take(
+ 1
+ )[0][0]
+
+ assert result.geom_type == "MultiPoint"
+ assert len(result.geoms) == 3
+
+ def test_st_collect_aggr_polygons(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+ ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("polygons_table")
+
+ result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM
polygons_table").take(
+ 1
+ )[0][0]
+
+ assert result.geom_type == "MultiPolygon"
+ assert len(result.geoms) == 2
+ assert result.area == 2.0
+
+ def test_st_collect_aggr_mixed_types(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POINT(1 2)'),
+ ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("mixed_geom_table")
+
+ result = self.spark.sql(
+ "SELECT ST_Collect_Agg(geom) FROM mixed_geom_table"
+ ).take(1)[0][0]
+
+ assert result.geom_type == "GeometryCollection"
+ assert len(result.geoms) == 3
+
+ def test_st_collect_aggr_preserves_duplicates(self):
+ # Test that ST_Collect_Agg keeps duplicate geometries (unlike
ST_Union_Aggr)
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("duplicate_polygons_table")
+
+ result = self.spark.sql(
+ "SELECT ST_Collect_Agg(geom) FROM duplicate_polygons_table"
+ ).take(1)[0][0]
+
+ # ST_Collect_Agg should preserve both polygons
+ assert len(result.geoms) == 2
+ # Area should be 2 because it doesn't merge overlapping areas
+ assert result.area == 2.0
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index 9683fc3029..939411b5c1 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -1223,6 +1223,13 @@ test_configurations = [
"",
"POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))",
),
+ (
+ sta.ST_Collect_Agg,
+ ("geom",),
+ "exploded_points",
+ "",
+ "MULTIPOINT ((0 0), (1 1))",
+ ),
]
wrong_type_configurations = [
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index f79c3762db..e584e666e3 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -356,7 +356,7 @@ object Catalog extends AbstractCatalog with Logging {
function[ST_GeomToGeography]()) ++ geoStatsFunctions()
val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
- Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr())
+ Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr(),
new ST_Collect_Agg())
private def geoStatsFunctions(): Seq[FunctionDescription] = {
// Try loading geostats functions. Return a seq of geo-stats functions. If
any error occurs,
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index c7724211ee..fc0cab6260 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -18,6 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.expressions
+import org.apache.sedona.common.Functions
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
@@ -191,3 +192,42 @@ private[apache] class ST_Intersection_Aggr
else buffer1.intersection(buffer2)
}
}
+
+/**
+ * Return a multi-geometry collection of all geometries in the given column. Unlike ST_Union_Aggr,
+ * this function does not dissolve boundaries between geometries.
+ *
+ * Null inputs are skipped while aggregating. When nothing was collected the result is a
+ * GeometryCollection created with no members; otherwise the buffered geometries are passed to
+ * `Functions.createMultiGeometry`.
+ */
+private[apache] class ST_Collect_Agg
+    extends Aggregator[Geometry, ListBuffer[Geometry], Geometry] {
+
+  val serde = ExpressionEncoder[Geometry]()
+  val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
+
+  /** Appends a non-null input geometry to the buffer; a null input leaves the buffer as is. */
+  override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
+    if (input != null) {
+      buffer += input
+    }
+    buffer
+  }
+
+  /** Appends every geometry of `buffer2` to `buffer1` in place and returns `buffer1`. */
+  override def merge(
+      buffer1: ListBuffer[Geometry],
+      buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
+    buffer1 ++= buffer2
+    buffer1
+  }
+
+  /** Turns the collected geometries into the final aggregate value. */
+  override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    if (reduction.isEmpty) {
+      // Nothing was collected (no rows, or only nulls): return a collection with no members.
+      new GeometryFactory().createGeometryCollection()
+    } else {
+      Functions.createMultiGeometry(reduction.toArray)
+    }
+  }
+
+  def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] = bufferSerde
+
+  def outputEncoder: ExpressionEncoder[Geometry] = serde
+
+  /** Starts every aggregation from a fresh, empty buffer. */
+  override def zero: ListBuffer[Geometry] = ListBuffer.empty
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
index ae2ac145be..fe0dc8d714 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
@@ -52,4 +52,14 @@ object st_aggregates {
val aggrFunc = udaf(new ST_Union_Aggr)
aggrFunc(col(geometry))
}
+
+  /** Applies the ST_Collect_Agg aggregator, wrapped as a UDAF, to the given geometry column. */
+  def ST_Collect_Agg(geometry: Column): Column =
+    udaf(new ST_Collect_Agg).apply(geometry)
+
+  /** Looks the geometry column up by name and aggregates it via the Column overload. */
+  def ST_Collect_Agg(geometry: String): Column =
+    ST_Collect_Agg(col(geometry))
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 9608cbcbe3..911769e2ac 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -128,6 +128,123 @@ class aggregateFunctionTestScala extends TestBaseScala {
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
}
+
+    it("Passed ST_Collect_Agg with points") {
+      // Three points should be collected into a single MultiPoint with three members.
+      sparkSession
+        .sql("""
+            |SELECT explode(array(
+            |  ST_GeomFromWKT('POINT(1 2)'),
+            |  ST_GeomFromWKT('POINT(3 4)'),
+            |  ST_GeomFromWKT('POINT(5 6)')
+            |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("points_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM points_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "MultiPoint")
+      assert(result.getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg with polygons") {
+      // Two disjoint unit squares should be collected into a two-member MultiPolygon.
+      sparkSession
+        .sql("""
+            |SELECT explode(array(
+            |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+            |  ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+            |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("polygons_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM polygons_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "MultiPolygon")
+      assert(result.getNumGeometries == 2)
+      // Total area should be 2 (each polygon has area 1)
+      assert(result.getArea == 2.0)
+    }
+
+    it("Passed ST_Collect_Agg with mixed geometry types") {
+      // A point, a line and a polygon should end up in a generic GeometryCollection.
+      sparkSession
+        .sql("""
+            |SELECT explode(array(
+            |  ST_GeomFromWKT('POINT(1 2)'),
+            |  ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+            |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+            |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("mixed_geom_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM mixed_geom_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "GeometryCollection")
+      assert(result.getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg with GROUP BY") {
+      // Group 1 holds two points and group 2 holds three; each group is collected separately.
+      sparkSession
+        .sql("""
+            |SELECT * FROM (VALUES
+            |  (1, ST_GeomFromWKT('POINT(1 2)')),
+            |  (1, ST_GeomFromWKT('POINT(3 4)')),
+            |  (2, ST_GeomFromWKT('POINT(5 6)')),
+            |  (2, ST_GeomFromWKT('POINT(7 8)')),
+            |  (2, ST_GeomFromWKT('POINT(9 10)'))
+            |) AS t(group_id, geom)
+          """.stripMargin)
+        .createOrReplaceTempView("grouped_points_table")
+
+      // ORDER BY group_id makes the row positions asserted below deterministic.
+      val collectDF = sparkSession.sql(
+        "SELECT group_id, ST_Collect_Agg(geom) as collected FROM grouped_points_table GROUP BY group_id ORDER BY group_id")
+      val results = collectDF.collect()
+
+      // Group 1 should have 2 points
+      assert(results(0).getAs[Geometry]("collected").getNumGeometries == 2)
+      // Group 2 should have 3 points
+      assert(results(1).getAs[Geometry]("collected").getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg preserves duplicates unlike ST_Union_Aggr") {
+      // Test that ST_Collect_Agg keeps duplicate geometries
+      // (unlike ST_Union_Aggr which merges them)
+      sparkSession
+        .sql("""
+            |SELECT explode(array(
+            |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+            |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+            |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("duplicate_polygons_table")
+
+      val collectDF =
+        sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM duplicate_polygons_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // ST_Collect_Agg should preserve both polygons
+      assert(result.getNumGeometries == 2)
+      // Area should be 2 because it doesn't merge overlapping areas
+      assert(result.getArea == 2.0)
+    }
+
+    it("Passed ST_Collect_Agg with null values") {
+      // The middle array element is a null geometry and must not appear in the result.
+      sparkSession
+        .sql("""
+            |SELECT explode(array(
+            |  ST_GeomFromWKT('POINT(1 2)'),
+            |  ST_GeomFromWKT(NULL),
+            |  ST_GeomFromWKT('POINT(3 4)')
+            |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("points_with_null_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM points_with_null_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should only have 2 points (nulls are skipped)
+      assert(result.getNumGeometries == 2)
+    }
}
def generateRandomPolygon(index: Int): String = {
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index 15901c67a1..81bd279c00 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -1814,6 +1814,15 @@ class dataFrameAPITestScala extends TestBaseScala {
assert(actualResult == expectedResult)
}
+    it("Passed ST_Collect_Agg") {
+      // Exercises the DataFrame-API overload that takes the geometry column by name.
+      val baseDf = sparkSession.sql(
+        "SELECT explode(array(ST_GeomFromWKT('POINT (1 2)'), ST_GeomFromWKT('POINT (3 4)'), ST_GeomFromWKT('POINT (5 6)'))) AS geom")
+      val df = baseDf.select(ST_Collect_Agg("geom"))
+      val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry]
+      assert(actualResult.getGeometryType == "MultiPoint")
+      assert(actualResult.getNumGeometries == 3)
+    }
+
it("Passed ST_LineFromMultiPoint") {
val baseDf = sparkSession.sql(
"SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30
10))') AS multipoint")