This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new c49c923e00 [SEDONA-724] Fix RS_ZonalStats and RS_ZonalStatsAll edge case bug (#1871)
c49c923e00 is described below

commit c49c923e007964251064f00a2af7abe3fb0238d8
Author: Furqaan Khan <[email protected]>
AuthorDate: Thu Mar 20 15:19:58 2025 -0400

    [SEDONA-724] Fix RS_ZonalStats and RS_ZonalStatsAll edge case bug (#1871)
    
    * fix: RS_ZonalStats and RS_ZonalStatsAll edge case
    
    * fix: spotless
    
    * change NaN to null
    
    * fix spotless
    
    * change all NaNs to nulls
---
 .../sedona/common/raster/RasterBandAccessors.java  | 76 +++++++++++++++-------
 .../common/raster/RasterBandAccessorsTest.java     | 56 ++++++++++++----
 .../org/apache/sedona/sql/rasteralgebraTest.scala  | 18 +++++
 3 files changed, 117 insertions(+), 33 deletions(-)

diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java 
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java
index 7f816d8b52..5e2502e2c9 100644
--- 
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java
+++ 
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java
@@ -99,7 +99,7 @@ public class RasterBandAccessors {
    * @return An array with all the stats for the region
    * @throws FactoryException
    */
-  public static double[] getZonalStatsAll(
+  public static Double[] getZonalStatsAll(
       GridCoverage2D raster,
       Geometry roi,
       int band,
@@ -114,18 +114,35 @@ public class RasterBandAccessors {
     DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
     double[] pixelData = (double[]) objects.get(1);
 
+    // Shortcut for an edge case where the ROI barely intersects the raster's extent but does not
+    // intersect the centroid of any pixel, so no pixel data is collected.
+    // This happens when the allTouched parameter is false.
+    if (pixelData.length == 0) {
+      return new Double[] {0.0, null, null, null, null, null, null, null, 
null};
+    }
+
     // order of stats
     // count, sum, mean, median, mode, stddev, variance, min, max
-    double[] result = new double[9];
-    result[0] = stats.getN();
-    result[1] = stats.getSum();
-    result[2] = stats.getMean();
-    result[3] = stats.getPercentile(50);
+    Double[] result = new Double[9];
+    result[0] = (double) stats.getN();
+    if (stats.getN() == 0) {
+      result[1] = null;
+    } else {
+      result[1] = stats.getSum();
+    }
+    double mean = stats.getMean();
+    result[2] = Double.isNaN(mean) ? null : mean;
+    double median = stats.getPercentile(50);
+    result[3] = Double.isNaN(median) ? null : median;
     result[4] = zonalMode(pixelData);
-    result[5] = stats.getStandardDeviation();
-    result[6] = stats.getVariance();
-    result[7] = stats.getMin();
-    result[8] = stats.getMax();
+    double stdDev = stats.getStandardDeviation();
+    result[5] = Double.isNaN(stdDev) ? null : stats.getStandardDeviation();
+    double variance = stats.getVariance();
+    result[6] = Double.isNaN(variance) ? null : variance;
+    double min = stats.getMin();
+    result[7] = Double.isNaN(min) ? null : min;
+    double max = stats.getMax();
+    result[8] = Double.isNaN(max) ? null : max;
 
     return result;
   }
@@ -139,7 +156,7 @@ public class RasterBandAccessors {
    * @return An array with all the stats for the region
    * @throws FactoryException
    */
-  public static double[] getZonalStatsAll(
+  public static Double[] getZonalStatsAll(
       GridCoverage2D raster, Geometry roi, int band, boolean allTouched, 
boolean excludeNoData)
       throws FactoryException {
     return getZonalStatsAll(raster, roi, band, allTouched, excludeNoData, 
true);
@@ -153,7 +170,7 @@ public class RasterBandAccessors {
    * @return An array with all the stats for the region, excludeNoData is set 
to true
    * @throws FactoryException
    */
-  public static double[] getZonalStatsAll(
+  public static Double[] getZonalStatsAll(
       GridCoverage2D raster, Geometry roi, int band, boolean allTouched) 
throws FactoryException {
     return getZonalStatsAll(raster, roi, band, allTouched, true);
   }
@@ -165,7 +182,7 @@ public class RasterBandAccessors {
    * @return An array with all the stats for the region, excludeNoData is set 
to true
    * @throws FactoryException
    */
-  public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, 
int band)
+  public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, 
int band)
       throws FactoryException {
     return getZonalStatsAll(raster, roi, band, false);
   }
@@ -177,7 +194,7 @@ public class RasterBandAccessors {
    *     set to 1
    * @throws FactoryException
    */
-  public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
+  public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
       throws FactoryException {
     return getZonalStatsAll(raster, roi, 1);
   }
@@ -213,26 +230,36 @@ public class RasterBandAccessors {
 
     switch (statType.toLowerCase()) {
       case "sum":
-        return stats.getSum();
+        if (pixelData.length == 0) {
+          return null;
+        } else {
+          return stats.getSum();
+        }
       case "average":
       case "avg":
       case "mean":
-        return stats.getMean();
+        double mean = stats.getMean();
+        return Double.isNaN(mean) ? null : mean;
       case "count":
         return (double) stats.getN();
       case "max":
-        return stats.getMax();
+        double max = stats.getMax();
+        return Double.isNaN(max) ? null : max;
       case "min":
-        return stats.getMin();
+        double min = stats.getMin();
+        return Double.isNaN(min) ? null : min;
       case "stddev":
       case "sd":
-        return stats.getStandardDeviation();
+        double stdDev = stats.getStandardDeviation();
+        return Double.isNaN(stdDev) ? null : stdDev;
       case "median":
-        return stats.getPercentile(50);
+        double median = stats.getPercentile(50);
+        return Double.isNaN(median) ? null : median;
       case "mode":
         return zonalMode(pixelData);
       case "variance":
-        return stats.getVariance();
+        double variance = stats.getVariance();
+        return Double.isNaN(variance) ? null : variance;
       default:
         throw new IllegalArgumentException(
             "Please select from the accepted options. Some of the valid 
options are sum, mean, stddev, etc.");
@@ -310,8 +337,13 @@ public class RasterBandAccessors {
    * @return Mode of the pixel values. If there is multiple with same 
occurrence, then the largest
    *     value will be returned.
    */
-  private static double zonalMode(double[] pixelData) {
+  private static Double zonalMode(double[] pixelData) {
     double[] modes = StatUtils.mode(pixelData);
+    // Return null when the ROI and the raster's extent overlap, but there's no pixel data.
+    // This behavior only happens when the allTouched parameter is false.
+    if (modes.length == 0) {
+      return null;
+    }
     return modes[modes.length - 1];
   }
 
diff --git 
a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java
 
b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java
index 94986d039a..21f4e446ae 100644
--- 
a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java
+++ 
b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java
@@ -28,7 +28,6 @@ import org.junit.Test;
 import org.locationtech.jts.geom.Geometry;
 import org.locationtech.jts.io.ParseException;
 import org.opengis.referencing.FactoryException;
-import org.opengis.referencing.operation.TransformException;
 
 public class RasterBandAccessorsTest extends RasterTestBase {
 
@@ -84,6 +83,31 @@ public class RasterBandAccessorsTest extends RasterTestBase {
     assertEquals("Provided band index 2 is not present in the raster", 
exception.getMessage());
   }
 
+  @Test
+  public void testZonalStatsIntersectingNoPixelData() throws FactoryException, 
ParseException {
+    double[][] pixelsValues =
+        new double[][] {
+          new double[] {
+            3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 
28, 45, 24, 99, 22, 18,
+            98, 10
+          }
+        };
+    GridCoverage2D raster =
+        RasterConstructors.makeNonEmptyRaster(1, "", 5, 5, 1, -1, 1, -1, 0, 0, 
0, pixelsValues);
+    Geometry extent =
+        Constructors.geomFromWKT(
+            "POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 
-5.834616, 5.822754 -5.834616, 5.822754 -6.620957))",
+            0);
+
+    Double actualZonalStats = RasterBandAccessors.getZonalStats(raster, 
extent, "mode");
+    assertNull(actualZonalStats);
+
+    String actualZonalStatsAll =
+        Arrays.toString(RasterBandAccessors.getZonalStatsAll(raster, extent));
+    String expectedZonalStatsAll = "[0.0, null, null, null, null, null, null, 
null, null]";
+    assertEquals(expectedZonalStatsAll, actualZonalStatsAll);
+  }
+
   @Test
   public void testZonalStats() throws FactoryException, ParseException, 
IOException {
     GridCoverage2D raster =
@@ -182,15 +206,17 @@ public class RasterBandAccessorsTest extends 
RasterTestBase {
   }
 
   @Test
-  public void testZonalStatsAll()
-      throws IOException, FactoryException, ParseException, TransformException 
{
+  public void testZonalStatsAll() throws IOException, FactoryException, 
ParseException {
     GridCoverage2D raster =
         rasterFromGeoTiff(resourceFolder + 
"raster_geotiff_color/FAA_UTM18N_NAD83.tif");
     String polygon =
         "POLYGON ((-8673439.6642 4572993.5327, -8673155.5737 4563873.2099, 
-8701890.3259 4562931.7093, -8682522.8735 4572703.8908, -8673439.6642 
4572993.5327))";
     Geometry geom = Constructors.geomFromWKT(polygon, 3857);
 
-    double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, false, false);
+    double[] actual =
+        Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, false, false))
+            .mapToDouble(Double::doubleValue)
+            .toArray();
     double[] expected =
         new double[] {
           185953.0,
@@ -209,16 +235,19 @@ public class RasterBandAccessorsTest extends 
RasterTestBase {
         Constructors.geomFromWKT(
             "POLYGON ((-77.96672569800863073 37.91971182746296876, 
-77.9688630154902711 37.89620133516485367, -77.93936803424354309 
37.90517806858776595, -77.96672569800863073 37.91971182746296876))",
             0);
-    actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, 
false, false);
+    actual =
+        Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, false, false))
+            .mapToDouble(Double::doubleValue)
+            .toArray();
     assertNotNull(actual);
 
     Geometry nonIntersectingGeom =
         Constructors.geomFromWKT(
             "POLYGON ((-78.22106647832458748 37.76411511479908967, 
-78.20183062098976734 37.72863564460374874, -78.18088490966962922 
37.76753482276972562, -78.22106647832458748 37.76411511479908967))",
             0);
-    actual =
+    Double[] actualNull =
         RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, 
false, false, true);
-    assertNull(actual);
+    assertNull(actualNull);
     assertThrows(
         IllegalArgumentException.class,
         () ->
@@ -227,15 +256,17 @@ public class RasterBandAccessorsTest extends 
RasterTestBase {
   }
 
   @Test
-  public void testZonalStatsAllWithNoData()
-      throws IOException, FactoryException, ParseException, TransformException 
{
+  public void testZonalStatsAllWithNoData() throws IOException, 
FactoryException, ParseException {
     GridCoverage2D raster =
         rasterFromGeoTiff(resourceFolder + 
"raster/raster_with_no_data/test5.tiff");
     String polygon =
         "POLYGON((-167.750000 87.750000, -155.250000 87.750000, -155.250000 
40.250000, -180.250000 40.250000, -167.750000 87.750000))";
     Geometry geom = Constructors.geomFromWKT(polygon, 
RasterAccessors.srid(raster));
 
-    double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, true);
+    double[] actual =
+        Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, true))
+            .mapToDouble(Double::doubleValue)
+            .toArray();
     double[] expected =
         new double[] {
           14249.0,
@@ -265,7 +296,10 @@ public class RasterBandAccessorsTest extends 
RasterTestBase {
     // Testing implicit CRS transformation
     Geometry geom = Constructors.geomFromWKT("POLYGON((2 -2, 2 -6, 6 -6, 6 -2, 
2 -2))", 0);
 
-    double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, true);
+    double[] actual =
+        Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, 
false, true))
+            .mapToDouble(Double::doubleValue)
+            .toArray();
     double[] expected = new double[] {13.0, 114.0, 8.7692, 9.0, 11.0, 4.7285, 
22.3589, 1.0, 16.0};
     assertArrayEquals(expected, actual, FP_TOLERANCE);
   }
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index 3e77f8f509..5832e54077 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -1623,6 +1623,24 @@ class rasteralgebraTest extends TestBaseScala with 
BeforeAndAfter with GivenWhen
       assertTrue(expectedSummary4.equals(actualSummary4))
     }
 
+    it("Passed RS_ZonalStats edge case") {
+      val df = sparkSession.sql("""
+          |with data as (
+          | SELECT array(3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 
53, 54, 86, 28, 45, 24, 99, 22, 18, 98, 10) as pixels,
+          |   ST_GeomFromWKT('POLYGON ((5.822754 -6.620957, 6.965332 
-6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))', 4326) 
as geom
+          |)
+          |
+          |SELECT RS_SetSRID(RS_AddBandFromArray(RS_MakeEmptyRaster(1, "D", 5, 
5, 1, -1, 1), pixels, 1), 4326) as raster, geom FROM data
+          |""".stripMargin)
+
+      val actual = df.selectExpr("RS_ZonalStats(raster, geom, 1, 
'mode')").first().get(0)
+      assertNull(actual)
+
+      val statsDf = df.selectExpr("RS_ZonalStatsAll(raster, geom) as stats")
+      val actualBoolean = 
statsDf.selectExpr("isNull(stats.mode)").first().getAs[Boolean](0)
+      assertTrue(actualBoolean)
+    }
+
     it("Passed RS_ZonalStats") {
       var df = sparkSession.read
         .format("binaryFile")

Reply via email to