This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new ec7d9a64f3 [GH-2545] Add ST_Collect_Agg aggregate function (#2546)
ec7d9a64f3 is described below

commit ec7d9a64f35a5ab93bf5c5d0b0969144d1dbd5b5
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Dec 4 23:03:03 2025 -0800

    [GH-2545] Add ST_Collect_Agg aggregate function (#2546)
---
 docs/api/sql/AggregateFunction.md                  |  32 ++++++
 python/sedona/spark/sql/st_aggregates.py           |  16 +++
 python/tests/sql/test_aggregate_functions.py       |  74 +++++++++++++
 python/tests/sql/test_dataframe_api.py             |   7 ++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |   2 +-
 .../expressions/AggregateFunctions.scala           |  40 +++++++
 .../sql/sedona_sql/expressions/st_aggregates.scala |  10 ++
 .../sedona/sql/aggregateFunctionTestScala.scala    | 117 +++++++++++++++++++++
 .../apache/sedona/sql/dataFrameAPITestScala.scala  |   9 ++
 9 files changed, 306 insertions(+), 1 deletion(-)

diff --git a/docs/api/sql/AggregateFunction.md b/docs/api/sql/AggregateFunction.md
index d3e4f0a3ab..8918c6d428 100644
--- a/docs/api/sql/AggregateFunction.md
+++ b/docs/api/sql/AggregateFunction.md
@@ -17,6 +17,38 @@
  under the License.
  -->
 
+## ST_Collect_Agg
+
+Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Aggr`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry.
+
+Format: `ST_Collect_Agg (A: geometryColumn)`
+
+Since: `v1.8.1`
+
+SQL Example
+
+```sql
+SELECT ST_Collect_Agg(geom) FROM (
+  SELECT ST_GeomFromWKT('POINT(1 2)') AS geom
+  UNION ALL
+  SELECT ST_GeomFromWKT('POINT(3 4)') AS geom
+  UNION ALL
+  SELECT ST_GeomFromWKT('POINT(5 6)') AS geom
+)
+```
+
+Output:
+
+```
+MULTIPOINT ((1 2), (3 4), (5 6))
+```
+
+SQL Example with GROUP BY
+
+```sql
+SELECT category, ST_Collect_Agg(geom) FROM geometries GROUP BY category
+```
+
 ## ST_Envelope_Aggr
 
 Introduction: Return the entire envelope boundary of all geometries in A
diff --git a/python/sedona/spark/sql/st_aggregates.py b/python/sedona/spark/sql/st_aggregates.py
index d9117e94f1..ec20e64307 100644
--- a/python/sedona/spark/sql/st_aggregates.py
+++ b/python/sedona/spark/sql/st_aggregates.py
@@ -65,6 +65,22 @@ def ST_Union_Aggr(geometry: ColumnOrName) -> Column:
     return _call_aggregate_function("ST_Union_Aggr", geometry)
 
 
+@validate_argument_types
+def ST_Collect_Agg(geometry: ColumnOrName) -> Column:
+    """Aggregate Function: Collect all geometries into a multi-geometry.
+
+    Unlike ST_Union_Aggr, this function does not dissolve boundaries between geometries.
+    It simply collects all geometries into a MultiPoint, MultiLineString, MultiPolygon,
+    or GeometryCollection based on the input geometry types.
+
+    :param geometry: Geometry column to aggregate.
+    :type geometry: ColumnOrName
+    :return: Multi-geometry representing the collection of all geometries in the column.
+    :rtype: Column
+    """
+    return _call_aggregate_function("ST_Collect_Agg", geometry)
+
+
 # Automatically populate __all__
 __all__ = [
     name
diff --git a/python/tests/sql/test_aggregate_functions.py b/python/tests/sql/test_aggregate_functions.py
index 335ca5687f..4df0e3e230 100644
--- a/python/tests/sql/test_aggregate_functions.py
+++ b/python/tests/sql/test_aggregate_functions.py
@@ -71,3 +71,77 @@ class TestConstructors(TestBase):
         )
 
         assert union.take(1)[0][0].area == 10100
+
+    def test_st_collect_aggr_points(self):
+        self.spark.sql(
+            """
+            SELECT explode(array(
+              ST_GeomFromWKT('POINT(1 2)'),
+              ST_GeomFromWKT('POINT(3 4)'),
+              ST_GeomFromWKT('POINT(5 6)')
+            )) AS geom
+            """
+        ).createOrReplaceTempView("points_table")
+
+        result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM points_table").take(
+            1
+        )[0][0]
+
+        assert result.geom_type == "MultiPoint"
+        assert len(result.geoms) == 3
+
+    def test_st_collect_aggr_polygons(self):
+        self.spark.sql(
+            """
+            SELECT explode(array(
+              ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+              ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+            )) AS geom
+            """
+        ).createOrReplaceTempView("polygons_table")
+
+        result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM polygons_table").take(
+            1
+        )[0][0]
+
+        assert result.geom_type == "MultiPolygon"
+        assert len(result.geoms) == 2
+        assert result.area == 2.0
+
+    def test_st_collect_aggr_mixed_types(self):
+        self.spark.sql(
+            """
+            SELECT explode(array(
+              ST_GeomFromWKT('POINT(1 2)'),
+              ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+              ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+            )) AS geom
+            """
+        ).createOrReplaceTempView("mixed_geom_table")
+
+        result = self.spark.sql(
+            "SELECT ST_Collect_Agg(geom) FROM mixed_geom_table"
+        ).take(1)[0][0]
+
+        assert result.geom_type == "GeometryCollection"
+        assert len(result.geoms) == 3
+
+    def test_st_collect_aggr_preserves_duplicates(self):
+        # Test that ST_Collect_Agg keeps duplicate geometries (unlike ST_Union_Aggr)
+        self.spark.sql(
+            """
+            SELECT explode(array(
+              ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+              ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+            )) AS geom
+            """
+        ).createOrReplaceTempView("duplicate_polygons_table")
+
+        result = self.spark.sql(
+            "SELECT ST_Collect_Agg(geom) FROM duplicate_polygons_table"
+        ).take(1)[0][0]
+
+        # ST_Collect_Agg should preserve both polygons
+        assert len(result.geoms) == 2
+        # Area should be 2 because it doesn't merge overlapping areas
+        assert result.area == 2.0
diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py
index 9683fc3029..939411b5c1 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -1223,6 +1223,13 @@ test_configurations = [
         "",
         "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))",
     ),
+    (
+        sta.ST_Collect_Agg,
+        ("geom",),
+        "exploded_points",
+        "",
+        "MULTIPOINT ((0 0), (1 1))",
+    ),
 ]
 
 wrong_type_configurations = [
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index f79c3762db..e584e666e3 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -356,7 +356,7 @@ object Catalog extends AbstractCatalog with Logging {
     function[ST_GeomToGeography]()) ++ geoStatsFunctions()
 
   val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
-    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr())
+    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr(), new ST_Collect_Agg())
 
   private def geoStatsFunctions(): Seq[FunctionDescription] = {
     // Try loading geostats functions. Return a seq of geo-stats functions. If any error occurs,
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index c7724211ee..fc0cab6260 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions
 
+import org.apache.sedona.common.Functions
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
@@ -191,3 +192,42 @@ private[apache] class ST_Intersection_Aggr
     else buffer1.intersection(buffer2)
   }
 }
+
+/**
+ * Return a multi-geometry collection of all geometries in the given column. Unlike ST_Union_Aggr,
+ * this function does not dissolve boundaries between geometries.
+ */
+private[apache] class ST_Collect_Agg
+    extends Aggregator[Geometry, ListBuffer[Geometry], Geometry] {
+
+  val serde = ExpressionEncoder[Geometry]()
+  val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
+
+  override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
+    if (input != null) {
+      buffer += input
+    }
+    buffer
+  }
+
+  override def merge(
+      buffer1: ListBuffer[Geometry],
+      buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
+    buffer1 ++= buffer2
+    buffer1
+  }
+
+  override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    if (reduction.isEmpty) {
+      new GeometryFactory().createGeometryCollection()
+    } else {
+      Functions.createMultiGeometry(reduction.toArray)
+    }
+  }
+
+  def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] = bufferSerde
+
+  def outputEncoder: ExpressionEncoder[Geometry] = serde
+
+  override def zero: ListBuffer[Geometry] = ListBuffer.empty
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
index ae2ac145be..fe0dc8d714 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala
@@ -52,4 +52,14 @@ object st_aggregates {
     val aggrFunc = udaf(new ST_Union_Aggr)
     aggrFunc(col(geometry))
   }
+
+  def ST_Collect_Agg(geometry: Column): Column = {
+    val aggrFunc = udaf(new ST_Collect_Agg)
+    aggrFunc(geometry)
+  }
+
+  def ST_Collect_Agg(geometry: String): Column = {
+    val aggrFunc = udaf(new ST_Collect_Agg)
+    aggrFunc(col(geometry))
+  }
 }
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 9608cbcbe3..911769e2ac 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -128,6 +128,123 @@ class aggregateFunctionTestScala extends TestBaseScala {
 
       assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
     }
+
+    it("Passed ST_Collect_Agg with points") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POINT(1 2)'),
+          |  ST_GeomFromWKT('POINT(3 4)'),
+          |  ST_GeomFromWKT('POINT(5 6)')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("points_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM points_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "MultiPoint")
+      assert(result.getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg with polygons") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+          |  ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("polygons_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM polygons_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "MultiPolygon")
+      assert(result.getNumGeometries == 2)
+      // Total area should be 2 (each polygon has area 1)
+      assert(result.getArea == 2.0)
+    }
+
+    it("Passed ST_Collect_Agg with mixed geometry types") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POINT(1 2)'),
+          |  ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+          |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("mixed_geom_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM mixed_geom_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      assert(result.getGeometryType == "GeometryCollection")
+      assert(result.getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg with GROUP BY") {
+      sparkSession
+        .sql("""
+          |SELECT * FROM (VALUES
+          |  (1, ST_GeomFromWKT('POINT(1 2)')),
+          |  (1, ST_GeomFromWKT('POINT(3 4)')),
+          |  (2, ST_GeomFromWKT('POINT(5 6)')),
+          |  (2, ST_GeomFromWKT('POINT(7 8)')),
+          |  (2, ST_GeomFromWKT('POINT(9 10)'))
+          |) AS t(group_id, geom)
+        """.stripMargin)
+        .createOrReplaceTempView("grouped_points_table")
+
+      val collectDF = sparkSession.sql(
+        "SELECT group_id, ST_Collect_Agg(geom) as collected FROM grouped_points_table GROUP BY group_id ORDER BY group_id")
+      val results = collectDF.collect()
+
+      // Group 1 should have 2 points
+      assert(results(0).getAs[Geometry]("collected").getNumGeometries == 2)
+      // Group 2 should have 3 points
+      assert(results(1).getAs[Geometry]("collected").getNumGeometries == 3)
+    }
+
+    it("Passed ST_Collect_Agg preserves duplicates unlike ST_Union_Aggr") {
+      // Test that ST_Collect_Agg keeps duplicate geometries (unlike ST_Union_Aggr which merges them)
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+          |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("duplicate_polygons_table")
+
+      val collectDF =
+        sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM duplicate_polygons_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // ST_Collect_Agg should preserve both polygons
+      assert(result.getNumGeometries == 2)
+      // Area should be 2 because it doesn't merge overlapping areas
+      assert(result.getArea == 2.0)
+    }
+
+    it("Passed ST_Collect_Agg with null values") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POINT(1 2)'),
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT('POINT(3 4)')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("points_with_null_table")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM points_with_null_table")
+      val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should only have 2 points (nulls are skipped)
+      assert(result.getNumGeometries == 2)
+    }
   }
 
   def generateRandomPolygon(index: Int): String = {
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index 15901c67a1..81bd279c00 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -1814,6 +1814,15 @@ class dataFrameAPITestScala extends TestBaseScala {
       assert(actualResult == expectedResult)
     }
 
+    it("Passed ST_Collect_Agg") {
+      val baseDf = sparkSession.sql(
+        "SELECT explode(array(ST_GeomFromWKT('POINT (1 2)'), ST_GeomFromWKT('POINT (3 4)'), ST_GeomFromWKT('POINT (5 6)'))) AS geom")
+      val df = baseDf.select(ST_Collect_Agg("geom"))
+      val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry]
+      assert(actualResult.getGeometryType == "MultiPoint")
+      assert(actualResult.getNumGeometries == 3)
+    }
+
     it("Passed ST_LineFromMultiPoint") {
       val baseDf = sparkSession.sql(
         "SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30 10))') AS multipoint")

Reply via email to