This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch prepare-1.7.2 in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 1e8b70ffdf09f1b7fe58488b48e6630f9ca056bd Author: Furqaan Khan <[email protected]> AuthorDate: Thu Mar 20 15:19:58 2025 -0400 [SEDONA-724] Fix RS_ZonalStats and RS_ZonalStatsAll edge case bug (#1871) * fix: RS_ZonalStats and RS_ZonalStatsAll edge case * fix: spotless * change NaN to null * fix spotless * change all NaNs to nulls --- .../sedona/common/raster/RasterBandAccessors.java | 76 +++++++++++++++------- .../common/raster/RasterBandAccessorsTest.java | 56 ++++++++++++---- .../org/apache/sedona/sql/rasteralgebraTest.scala | 18 +++++ 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java index b843ce924c..f5011d5a1c 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java @@ -99,7 +99,7 @@ public class RasterBandAccessors { * @return An array with all the stats for the region * @throws FactoryException */ - public static double[] getZonalStatsAll( + public static Double[] getZonalStatsAll( GridCoverage2D raster, Geometry roi, int band, @@ -114,18 +114,35 @@ public class RasterBandAccessors { DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0); double[] pixelData = (double[]) objects.get(1); + // Shortcut for an edge case where ROI barely intersects with raster's extent, but it doesn't + // intersect with the centroid of the pixel. + // This happens when allTouched parameter is false. 
+ if (pixelData.length == 0) { + return new Double[] {0.0, null, null, null, null, null, null, null, null}; + } + // order of stats // count, sum, mean, median, mode, stddev, variance, min, max - double[] result = new double[9]; - result[0] = stats.getN(); - result[1] = stats.getSum(); - result[2] = stats.getMean(); - result[3] = stats.getPercentile(50); + Double[] result = new Double[9]; + result[0] = (double) stats.getN(); + if (stats.getN() == 0) { + result[1] = null; + } else { + result[1] = stats.getSum(); + } + double mean = stats.getMean(); + result[2] = Double.isNaN(mean) ? null : mean; + double median = stats.getPercentile(50); + result[3] = Double.isNaN(median) ? null : median; result[4] = zonalMode(pixelData); - result[5] = stats.getStandardDeviation(); - result[6] = stats.getVariance(); - result[7] = stats.getMin(); - result[8] = stats.getMax(); + double stdDev = stats.getStandardDeviation(); + result[5] = Double.isNaN(stdDev) ? null : stats.getStandardDeviation(); + double variance = stats.getVariance(); + result[6] = Double.isNaN(variance) ? null : variance; + double min = stats.getMin(); + result[7] = Double.isNaN(min) ? null : min; + double max = stats.getMax(); + result[8] = Double.isNaN(max) ? 
null : max; return result; } @@ -139,7 +156,7 @@ public class RasterBandAccessors { * @return An array with all the stats for the region * @throws FactoryException */ - public static double[] getZonalStatsAll( + public static Double[] getZonalStatsAll( GridCoverage2D raster, Geometry roi, int band, boolean allTouched, boolean excludeNoData) throws FactoryException { return getZonalStatsAll(raster, roi, band, allTouched, excludeNoData, true); @@ -153,7 +170,7 @@ public class RasterBandAccessors { * @return An array with all the stats for the region, excludeNoData is set to true * @throws FactoryException */ - public static double[] getZonalStatsAll( + public static Double[] getZonalStatsAll( GridCoverage2D raster, Geometry roi, int band, boolean allTouched) throws FactoryException { return getZonalStatsAll(raster, roi, band, allTouched, true); } @@ -165,7 +182,7 @@ public class RasterBandAccessors { * @return An array with all the stats for the region, excludeNoData is set to true * @throws FactoryException */ - public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band) + public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band) throws FactoryException { return getZonalStatsAll(raster, roi, band, false); } @@ -177,7 +194,7 @@ public class RasterBandAccessors { * set to 1 * @throws FactoryException */ - public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) + public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) throws FactoryException { return getZonalStatsAll(raster, roi, 1); } @@ -213,26 +230,36 @@ public class RasterBandAccessors { switch (statType.toLowerCase()) { case "sum": - return stats.getSum(); + if (pixelData.length == 0) { + return null; + } else { + return stats.getSum(); + } case "average": case "avg": case "mean": - return stats.getMean(); + double mean = stats.getMean(); + return Double.isNaN(mean) ? 
null : mean; case "count": return (double) stats.getN(); case "max": - return stats.getMax(); + double max = stats.getMax(); + return Double.isNaN(max) ? null : max; case "min": - return stats.getMin(); + double min = stats.getMin(); + return Double.isNaN(min) ? null : min; case "stddev": case "sd": - return stats.getStandardDeviation(); + double stdDev = stats.getStandardDeviation(); + return Double.isNaN(stdDev) ? null : stdDev; case "median": - return stats.getPercentile(50); + double median = stats.getPercentile(50); + return Double.isNaN(median) ? null : median; case "mode": return zonalMode(pixelData); case "variance": - return stats.getVariance(); + double variance = stats.getVariance(); + return Double.isNaN(variance) ? null : variance; default: throw new IllegalArgumentException( "Please select from the accepted options. Some of the valid options are sum, mean, stddev, etc."); @@ -310,8 +337,13 @@ public class RasterBandAccessors { * @return Mode of the pixel values. If there is multiple with same occurrence, then the largest * value will be returned. */ - private static double zonalMode(double[] pixelData) { + private static Double zonalMode(double[] pixelData) { double[] modes = StatUtils.mode(pixelData); + // Return null when ROI and raster's extent overlap, but there's no pixel data. + // This behavior only happens when allTouched parameter is false. 
+ if (modes.length == 0) { + return null; + } return modes[modes.length - 1]; } diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java index 94986d039a..21f4e446ae 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java @@ -28,7 +28,6 @@ import org.junit.Test; import org.locationtech.jts.geom.Geometry; import org.locationtech.jts.io.ParseException; import org.opengis.referencing.FactoryException; -import org.opengis.referencing.operation.TransformException; public class RasterBandAccessorsTest extends RasterTestBase { @@ -84,6 +83,31 @@ public class RasterBandAccessorsTest extends RasterTestBase { assertEquals("Provided band index 2 is not present in the raster", exception.getMessage()); } + @Test + public void testZonalStatsIntersectingNoPixelData() throws FactoryException, ParseException { + double[][] pixelsValues = + new double[][] { + new double[] { + 3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18, + 98, 10 + } + }; + GridCoverage2D raster = + RasterConstructors.makeNonEmptyRaster(1, "", 5, 5, 1, -1, 1, -1, 0, 0, 0, pixelsValues); + Geometry extent = + Constructors.geomFromWKT( + "POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))", + 0); + + Double actualZonalStats = RasterBandAccessors.getZonalStats(raster, extent, "mode"); + assertNull(actualZonalStats); + + String actualZonalStatsAll = + Arrays.toString(RasterBandAccessors.getZonalStatsAll(raster, extent)); + String expectedZonalStatsAll = "[0.0, null, null, null, null, null, null, null, null]"; + assertEquals(expectedZonalStatsAll, actualZonalStatsAll); + } + @Test public void testZonalStats() throws FactoryException, ParseException, IOException { 
GridCoverage2D raster = @@ -182,15 +206,17 @@ public class RasterBandAccessorsTest extends RasterTestBase { } @Test - public void testZonalStatsAll() - throws IOException, FactoryException, ParseException, TransformException { + public void testZonalStatsAll() throws IOException, FactoryException, ParseException { GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif"); String polygon = "POLYGON ((-8673439.6642 4572993.5327, -8673155.5737 4563873.2099, -8701890.3259 4562931.7093, -8682522.8735 4572703.8908, -8673439.6642 4572993.5327))"; Geometry geom = Constructors.geomFromWKT(polygon, 3857); - double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false); + double[] actual = + Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false)) + .mapToDouble(Double::doubleValue) + .toArray(); double[] expected = new double[] { 185953.0, @@ -209,16 +235,19 @@ public class RasterBandAccessorsTest extends RasterTestBase { Constructors.geomFromWKT( "POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0); - actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false); + actual = + Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false)) + .mapToDouble(Double::doubleValue) + .toArray(); assertNotNull(actual); Geometry nonIntersectingGeom = Constructors.geomFromWKT( "POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0); - actual = + Double[] actualNull = RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false, true); - assertNull(actual); + assertNull(actualNull); assertThrows( IllegalArgumentException.class, () -> @@ 
-227,15 +256,17 @@ public class RasterBandAccessorsTest extends RasterTestBase { } @Test - public void testZonalStatsAllWithNoData() - throws IOException, FactoryException, ParseException, TransformException { + public void testZonalStatsAllWithNoData() throws IOException, FactoryException, ParseException { GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff"); String polygon = "POLYGON((-167.750000 87.750000, -155.250000 87.750000, -155.250000 40.250000, -180.250000 40.250000, -167.750000 87.750000))"; Geometry geom = Constructors.geomFromWKT(polygon, RasterAccessors.srid(raster)); - double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true); + double[] actual = + Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true)) + .mapToDouble(Double::doubleValue) + .toArray(); double[] expected = new double[] { 14249.0, @@ -265,7 +296,10 @@ public class RasterBandAccessorsTest extends RasterTestBase { // Testing implicit CRS transformation Geometry geom = Constructors.geomFromWKT("POLYGON((2 -2, 2 -6, 6 -6, 6 -2, 2 -2))", 0); - double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true); + double[] actual = + Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true)) + .mapToDouble(Double::doubleValue) + .toArray(); double[] expected = new double[] {13.0, 114.0, 8.7692, 9.0, 11.0, 4.7285, 22.3589, 1.0, 16.0}; assertArrayEquals(expected, actual, FP_TOLERANCE); } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 3e77f8f509..5832e54077 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -1623,6 +1623,24 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen 
assertTrue(expectedSummary4.equals(actualSummary4)) } + it("Passed RS_ZonalStats edge case") { + val df = sparkSession.sql(""" + |with data as ( + | SELECT array(3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18, 98, 10) as pixels, + | ST_GeomFromWKT('POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))', 4326) as geom + |) + | + |SELECT RS_SetSRID(RS_AddBandFromArray(RS_MakeEmptyRaster(1, "D", 5, 5, 1, -1, 1), pixels, 1), 4326) as raster, geom FROM data + |""".stripMargin) + + val actual = df.selectExpr("RS_ZonalStats(raster, geom, 1, 'mode')").first().get(0) + assertNull(actual) + + val statsDf = df.selectExpr("RS_ZonalStatsAll(raster, geom) as stats") + val actualBoolean = statsDf.selectExpr("isNull(stats.mode)").first().getAs[Boolean](0) + assertTrue(actualBoolean) + } + it("Passed RS_ZonalStats") { var df = sparkSession.read .format("binaryFile")
