This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new fa32f44ab6 [GH-2565] Fix NULL handling for various aggregation functions in SedonaSpark (#2563)
fa32f44ab6 is described below

commit fa32f44ab6cd3580a871798ff0aab7a12374d0aa
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Dec 18 02:28:21 2025 +0800

    [GH-2565] Fix NULL handling for various aggregation functions in SedonaSpark (#2563)
---
 .../expressions/AggregateFunctions.scala           | 170 ++++++++++-----------
 .../sedona/sql/aggregateFunctionTestScala.scala    | 149 ++++++++++++++++++
 2 files changed, 234 insertions(+), 85 deletions(-)
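
For context, the behavior this patch establishes can be sketched as follows (illustrative only; it assumes a SparkSession named `sparkSession` with the Sedona SQL functions registered, and mirrors the tests added below):

    sparkSession.sql("""
      |SELECT explode(array(
      |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
      |  ST_GeomFromWKT(NULL),
      |  ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
      |)) AS geom
    """.stripMargin).createOrReplaceTempView("geoms")

    // NULL inputs are now skipped: this unions only the two non-null
    // polygons (total area 2.0). If every input is NULL, the aggregate
    // returns NULL instead of a sentinel geometry.
    sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM geoms")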

diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index fc0cab6260..ca169a2598 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -19,9 +19,10 @@
 package org.apache.spark.sql.sedona_sql.expressions
 
 import org.apache.sedona.common.Functions
+import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.geom.{Coordinate, Envelope, Geometry, GeometryFactory}
 import org.locationtech.jts.operation.overlayng.OverlayNGRobust
 
 import scala.collection.JavaConverters._
@@ -32,18 +33,7 @@ import scala.collection.mutable.ListBuffer
  */
 
 trait TraitSTAggregateExec {
-  val initialGeometry: Geometry = {
-    // dummy value for initial value(polygon but )
-    // any other value is ok.
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    coordinates(0) = new Coordinate(-999999999, -999999999)
-    coordinates(1) = new Coordinate(-999999999, -999999999)
-    coordinates(2) = new Coordinate(-999999999, -999999999)
-    coordinates(3) = new Coordinate(-999999999, -999999999)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
-  }
+  val initialGeometry: Geometry = null
   val serde = ExpressionEncoder[Geometry]()
 
   def zero: Geometry = initialGeometry
@@ -62,7 +52,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
   val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
 
  override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
-    buffer += input
+    if (input != null) {
+      buffer += input
+    }
     if (buffer.size >= bufferSize) {
       // Perform the union when buffer size is reached
       val unionGeometry = OverlayNGRobust.union(buffer.asJava)
@@ -86,6 +78,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
   }
 
   override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    if (reduction.isEmpty) {
+      return null
+    }
     OverlayNGRobust.union(reduction.asJava)
   }
 
@@ -97,81 +92,76 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
 }
 
 /**
- * Return the envelope boundary of the entire column
+ * A helper class to store envelope boundary during aggregation. We use this custom case class
+ * instead of JTS Envelope to work with the Spark Encoder.
  */
-private[apache] class ST_Envelope_Aggr
-    extends Aggregator[Geometry, Geometry, Geometry]
-    with TraitSTAggregateExec {
+case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
+  def isNull: Boolean = minX > maxX
 
-  def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    val accumulateEnvelope = buffer.getEnvelopeInternal
-    val newEnvelope = input.getEnvelopeInternal
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    var minX = 0.0
-    var minY = 0.0
-    var maxX = 0.0
-    var maxY = 0.0
-    if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      // Found the accumulateEnvelope is the initial value
-      minX = newEnvelope.getMinX
-      minY = newEnvelope.getMinY
-      maxX = newEnvelope.getMaxX
-      maxY = newEnvelope.getMaxY
-    } else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = accumulateEnvelope.getMinX
-      minY = accumulateEnvelope.getMinY
-      maxX = accumulateEnvelope.getMaxX
-      maxY = accumulateEnvelope.getMaxY
+  def toEnvelope: Envelope = {
+    if (isNull) {
+      new Envelope()
     } else {
-      minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
-      minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
-      maxX = Math.max(accumulateEnvelope.getMaxX, newEnvelope.getMaxX)
-      maxY = Math.max(accumulateEnvelope.getMaxY, newEnvelope.getMaxY)
+      new Envelope(minX, maxX, minY, maxY)
     }
-    coordinates(0) = new Coordinate(minX, minY)
-    coordinates(1) = new Coordinate(minX, maxY)
-    coordinates(2) = new Coordinate(maxX, maxY)
-    coordinates(3) = new Coordinate(maxX, minY)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
-
   }
 
-  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    val leftEnvelope = buffer1.getEnvelopeInternal
-    val rightEnvelope = buffer2.getEnvelopeInternal
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    var minX = 0.0
-    var minY = 0.0
-    var maxX = 0.0
-    var maxY = 0.0
-    if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = rightEnvelope.getMinX
-      minY = rightEnvelope.getMinY
-      maxX = rightEnvelope.getMaxX
-      maxY = rightEnvelope.getMaxY
-    } else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = leftEnvelope.getMinX
-      minY = leftEnvelope.getMinY
-      maxX = leftEnvelope.getMaxX
-      maxY = leftEnvelope.getMaxY
+  def merge(other: EnvelopeBuffer): EnvelopeBuffer = {
+    if (this.isNull) {
+      other
+    } else if (other.isNull) {
+      this
     } else {
-      minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
-      minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
-      maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
-      maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
+      EnvelopeBuffer(
+        math.min(this.minX, other.minX),
+        math.max(this.maxX, other.maxX),
+        math.min(this.minY, other.minY),
+        math.max(this.maxY, other.maxY))
     }
+  }
+}
 
-    coordinates(0) = new Coordinate(minX, minY)
-    coordinates(1) = new Coordinate(minX, maxY)
-    coordinates(2) = new Coordinate(maxX, maxY)
-    coordinates(3) = new Coordinate(maxX, minY)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
+/**
+ * Return the envelope boundary of the entire column
+ */
+private[apache] class ST_Envelope_Aggr
+    extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {
+
+  val serde = ExpressionEncoder[Geometry]()
+
+  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
+    if (input == null) return buffer
+    val env = input.getEnvelopeInternal
+    val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
+    buffer match {
+      case Some(b) => Some(b.merge(envBuffer))
+      case None => Some(envBuffer)
+    }
+  }
+
+  def merge(
+      buffer1: Option[EnvelopeBuffer],
+      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
+    (buffer1, buffer2) match {
+      case (Some(b1), Some(b2)) => Some(b1.merge(b2))
+      case (Some(_), None) => buffer1
+      case (None, Some(_)) => buffer2
+      case (None, None) => None
+    }
+  }
+
+  def finish(reduction: Option[EnvelopeBuffer]): Geometry = {
+    reduction match {
+      case Some(b) => new GeometryFactory().toGeometry(b.toEnvelope)
+      case None => null
+    }
   }
 
+  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]
+
+  def outputEncoder: ExpressionEncoder[Geometry] = serde
+
+  def zero: Option[EnvelopeBuffer] = None
 }
 
 /**
@@ -181,16 +171,26 @@ private[apache] class ST_Intersection_Aggr
     extends Aggregator[Geometry, Geometry, Geometry]
     with TraitSTAggregateExec {
   def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    if (buffer.isEmpty) input
-    else if (buffer.equalsExact(initialGeometry)) input
-    else buffer.intersection(input)
+    if (input == null) {
+      return buffer
+    }
+    if (buffer == null) {
+      return input
+    }
+    buffer.intersection(input)
   }
 
   def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    if (buffer1.equalsExact(initialGeometry)) buffer2
-    else if (buffer2.equalsExact(initialGeometry)) buffer1
-    else buffer1.intersection(buffer2)
+    if (buffer1 == null) {
+      return buffer2
+    }
+    if (buffer2 == null) {
+      return buffer1
+    }
+    buffer1.intersection(buffer2)
   }
+
+  override def finish(out: Geometry): Geometry = out
 }
 
 /**
@@ -219,7 +219,7 @@ private[apache] class ST_Collect_Agg
 
   override def finish(reduction: ListBuffer[Geometry]): Geometry = {
     if (reduction.isEmpty) {
-      new GeometryFactory().createGeometryCollection()
+      null
     } else {
       Functions.createMultiGeometry(reduction.toArray)
     }
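
In summary: ST_Envelope_Aggr now carries an Option[EnvelopeBuffer] buffer (zero is None, reduce skips NULL rows, finish maps None to SQL NULL), while the TraitSTAggregateExec-based aggregates reach the same result by using null as the zero value and treating a null side as the identity in reduce/merge. A standalone sketch of the EnvelopeBuffer merge semantics (no Spark required; the values are illustrative):

    // minX > maxX encodes the "null" envelope, matching JTS's convention.
    val empty = EnvelopeBuffer(1.0, -1.0, 1.0, -1.0) // isNull == true
    val a = EnvelopeBuffer(1.0, 3.0, 2.0, 4.0)
    val b = EnvelopeBuffer(0.0, 2.0, 3.0, 5.0)

    assert(empty.merge(a) == a) // a null side acts as the identity
    // otherwise coordinate-wise min/max over (minX, maxX, minY, maxY):
    assert(a.merge(b) == EnvelopeBuffer(0.0, 3.0, 2.0, 5.0))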
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 911769e2ac..4485f9fcfe 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -245,6 +245,155 @@ class aggregateFunctionTestScala extends TestBaseScala {
       // Should only have 2 points (nulls are skipped)
       assert(result.getNumGeometries == 2)
     }
+
+    it("ST_Union_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("polygons_with_null_for_union")
+
+      val unionDF =
+        sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM 
polygons_with_null_for_union")
+      val result = unionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should union the 2 non-null polygons (total area = 2.0)
+      assert(result.getArea == 2.0)
+    }
+
+    it("ST_Envelope_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POINT(1 2)'),
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT('POINT(3 4)')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("points_with_null_for_envelope")
+
+      val envelopeDF =
+        sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM 
points_with_null_for_envelope")
+      val result = envelopeDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should create envelope from the 2 non-null points
+      assert(result.getGeometryType == "Polygon")
+      val envelope = result.getEnvelopeInternal
+      assert(envelope.getMinX == 1.0)
+      assert(envelope.getMinY == 2.0)
+      assert(envelope.getMaxX == 3.0)
+      assert(envelope.getMaxY == 4.0)
+    }
+
+    it("ST_Intersection_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT('POLYGON((0 0, 4 0, 4 4, 0 4, 0 0))'),
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT('POLYGON((2 2, 6 2, 6 6, 2 6, 2 2))')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("polygons_with_null_for_intersection")
+
+      val intersectionDF = sparkSession.sql(
+        "SELECT ST_Intersection_Aggr(geom) FROM 
polygons_with_null_for_intersection")
+      val result = intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should intersect the 2 non-null polygons (intersection area = 4.0)
+      assert(result.getArea == 4.0)
+    }
+
+    it("ST_Union_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT(NULL)
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("all_null_union")
+
+      val unionDF = sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM all_null_union")
+      val result = unionDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Envelope_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT(NULL)
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("all_null_envelope")
+
+      val envelopeDF = sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM all_null_envelope")
+      val result = envelopeDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Intersection_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT(NULL)
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("all_null_intersection")
+
+      val intersectionDF =
+        sparkSession.sql("SELECT ST_Intersection_Aggr(geom) FROM 
all_null_intersection")
+      val result = intersectionDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Collect_Agg should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  ST_GeomFromWKT(NULL),
+          |  ST_GeomFromWKT(NULL)
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("all_null_collect")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM all_null_collect")
+      val result = collectDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it(
+      "ST_Envelope_Aggr should return empty geometry if inputs are mixed with 
null and empty geometries") {
+      sparkSession
+        .sql("""
+          |SELECT explode(array(
+          |  NULL,
+          |  NULL,
+          |  ST_GeomFromWKT('POINT EMPTY'),
+          |  NULL,
+          |  ST_GeomFromWKT('POLYGON EMPTY')
+          |)) AS geom
+        """.stripMargin)
+        .createOrReplaceTempView("mixed_null_empty_envelope")
+
+      val envelopeDF =
+        sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM 
mixed_null_empty_envelope")
+      val result = envelopeDF.take(1)(0).get(0)
+
+      assert(result != null)
+      assert(result.asInstanceOf[Geometry].isEmpty)
+    }
   }
 
   def generateRandomPolygon(index: Int): String = {

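One subtlety the last test pins down: NULL inputs and empty geometries are different. An all-NULL column yields SQL NULL, while a column whose only inputs are NULLs and empty geometries yields an empty geometry, because JTS maps the null Envelope to an empty geometry. A minimal sketch of that JTS behavior:

    import org.locationtech.jts.geom.{Envelope, GeometryFactory}

    // new Envelope() is the "null" envelope (isNull == true); toGeometry
    // converts it to an empty geometry rather than returning null.
    val g = new GeometryFactory().toGeometry(new Envelope())
    assert(g.isEmpty)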