Copilot commented on code in PR #2546:
URL: https://github.com/apache/sedona/pull/2546#discussion_r2591628672
##########
spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala:
##########
@@ -128,6 +128,123 @@ class aggregateFunctionTestScala extends TestBaseScala {
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
}
+
+ it("Passed ST_Collect_Agg with points") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POINT(1 2)'),
+ | ST_GeomFromWKT('POINT(3 4)'),
+ | ST_GeomFromWKT('POINT(5 6)')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("points_table")
+
+ val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM
points_table")
+ val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ assert(result.getGeometryType == "MultiPoint")
+ assert(result.getNumGeometries == 3)
+ }
+
+ it("Passed ST_Collect_Agg with polygons") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+ | ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("polygons_table")
+
+ val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM
polygons_table")
+ val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ assert(result.getGeometryType == "MultiPolygon")
+ assert(result.getNumGeometries == 2)
+ // Total area should be 2 (each polygon has area 1)
+ assert(result.getArea == 2.0)
+ }
+
+ it("Passed ST_Collect_Agg with mixed geometry types") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POINT(1 2)'),
+ | ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+ | ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("mixed_geom_table")
+
+ val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM
mixed_geom_table")
+ val result = collectDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ assert(result.getGeometryType == "GeometryCollection")
+ assert(result.getNumGeometries == 3)
+ }
+
+ it("Passed ST_Collect_Agg with GROUP BY") {
+ sparkSession
+ .sql("""
+ |SELECT * FROM (VALUES
+ | (1, ST_GeomFromWKT('POINT(1 2)')),
+ | (1, ST_GeomFromWKT('POINT(3 4)')),
+ | (2, ST_GeomFromWKT('POINT(5 6)')),
+ | (2, ST_GeomFromWKT('POINT(7 8)')),
+ | (2, ST_GeomFromWKT('POINT(9 10)'))
+ |) AS t(group_id, geom)
+ """.stripMargin)
+ .createOrReplaceTempView("grouped_points_table")
+
+ val collectDF = sparkSession.sql(
+ "SELECT group_id, ST_Collect_Agg(geom) as collected FROM
grouped_points_table GROUP BY group_id ORDER BY group_id")
+ val results = collectDF.collect()
+
+ // Group 1 should have 2 points
+ assert(results(0).getAs[Geometry]("collected").getNumGeometries == 2)
+ // Group 2 should have 3 points
+ assert(results(1).getAs[Geometry]("collected").getNumGeometries == 3)
+ }
+
+ it("Passed ST_Collect_Agg preserves duplicates unlike ST_Union_Aggr") {
Review Comment:
Correct the spelling of 'ST_Union_Aggr' to 'ST_Union_Agg' in the test name so
that it matches the actual function name used elsewhere in the codebase.
```suggestion
it("Passed ST_Collect_Agg preserves duplicates unlike ST_Union_Agg") {
```
##########
python/tests/sql/test_aggregate_functions.py:
##########
@@ -71,3 +71,77 @@ def test_st_union_aggr(self):
)
assert union.take(1)[0][0].area == 10100
+
+ def test_st_collect_aggr_points(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POINT(1 2)'),
+ ST_GeomFromWKT('POINT(3 4)'),
+ ST_GeomFromWKT('POINT(5 6)')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("points_table")
+
+ result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM
points_table").take(
+ 1
+ )[0][0]
+
+ assert result.geom_type == "MultiPoint"
+ assert len(result.geoms) == 3
+
+ def test_st_collect_aggr_polygons(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+ ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("polygons_table")
+
+ result = self.spark.sql("SELECT ST_Collect_Agg(geom) FROM
polygons_table").take(
+ 1
+ )[0][0]
+
+ assert result.geom_type == "MultiPolygon"
+ assert len(result.geoms) == 2
+ assert result.area == 2.0
+
+ def test_st_collect_aggr_mixed_types(self):
+ self.spark.sql(
+ """
+ SELECT explode(array(
+ ST_GeomFromWKT('POINT(1 2)'),
+ ST_GeomFromWKT('LINESTRING(0 0, 1 1)'),
+ ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))')
+ )) AS geom
+ """
+ ).createOrReplaceTempView("mixed_geom_table")
+
+ result = self.spark.sql(
+ "SELECT ST_Collect_Agg(geom) FROM mixed_geom_table"
+ ).take(1)[0][0]
+
+ assert result.geom_type == "GeometryCollection"
+ assert len(result.geoms) == 3
+
+ def test_st_collect_aggr_preserves_duplicates(self):
+ # Test that ST_Collect_Agg keeps duplicate geometries (unlike
ST_Union_Aggr)
Review Comment:
Correct the spelling of 'ST_Union_Aggr' to 'ST_Union_Agg' in the comment so
that it matches the actual function name.
```suggestion
# Test that ST_Collect_Agg keeps duplicate geometries (unlike ST_Union_Agg)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]