This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new d1aed8531a [SEDONA-748] Fix issue with no optimization for weighting function (#2490)
d1aed8531a is described below

commit d1aed8531a12724705fd62de4f967c54dd4375ba
Author: Paweł Tokaj <[email protected]>
AuthorDate: Thu Nov 13 01:22:11 2025 +0100

    [SEDONA-748] Fix issue with no optimization for weighting function (#2490)
---
 .../scala/org/apache/sedona/stats/Weighting.scala  | 30 ++++++++++------
 .../org/apache/sedona/stats/WeightingTest.scala    | 42 ++++++++++++++++++++--
 2 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
index d404f2c2db..bd2ac8ed4a 100644
--- a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
@@ -109,30 +109,40 @@ object Weighting {
 
     val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256))
 
-    formattedDataFrame
+    val spatiallyJoined = formattedDataFrame
       .alias("l")
       .join(
         formattedDataFrame.alias("r"),
-        joinCondition && col(s"l.$ID_COLUMN") =!= col(
-          s"r.$ID_COLUMN"
-        ), // we will add self back later if self.includeSelf
+        joinCondition && col(s"l.$ID_COLUMN") =!= col(s"r.$ID_COLUMN"),
+        "inner")
+      .select(struct("l.*").alias("left"), struct("r.*").alias("right"))
+
+    val mapped = formattedDataFrame
+      .alias("f")
+      .join(
+        spatiallyJoined.alias("s"),
+        col(s"s.left.$ID_COLUMN") === col(s"f.$ID_COLUMN"),
         "left")
       .select(
-        col(s"l.$ID_COLUMN"),
-        struct("l.*").alias("left_contents"),
+        col(ID_COLUMN),
+        struct("f.*").alias("left_contents"),
         struct(
           (
             savedAttributesWithGeom match {
-              case null => struct(col("r.*")).dropFields(ID_COLUMN)
+              case null => struct(col("s.right.*")).dropFields(ID_COLUMN)
               case _ =>
-                struct(savedAttributesWithGeom.map(c => col(s"r.$c")): _*)
+                struct(savedAttributesWithGeom.map(c => col(s"s.right.$c")): _*)
             }
           ).alias("neighbor"),
           if (!binary)
-            pow(distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")), alpha)
+            pow(
+              distanceFunction(col(s"s.left.$geometryColumn"), col(s"s.right.$geometryColumn")),
+              alpha)
               .alias("value")
           else lit(1.0).alias("value")).alias("weight"))
-      .groupBy(s"l.$ID_COLUMN")
+
+    mapped
+      .groupBy(ID_COLUMN)
       .agg(
         first("left_contents").alias("left_contents"),
         concat(
diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
index a7a8865dda..4fcdaa654d 100644
--- a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
@@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala
 import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
 import org.apache.spark.sql.{DataFrame, Row, functions => f}
 
+import java.io.{ByteArrayOutputStream, PrintStream}
+
 class WeightingTest extends TestBaseScala {
 
   case class Neighbors(id: Int, neighbor: Seq[Int])
@@ -78,6 +80,9 @@ class WeightingTest extends TestBaseScala {
           f.col("id"),
           f.array_sort(
             f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")))
+
+      hasOptimizationTurnedOn(actualDf)
+
       val expectedDf = sparkSession.createDataFrame(
         Seq(
           Neighbors(0, Seq(1, 3, 5, 7)),
@@ -97,6 +102,7 @@ class WeightingTest extends TestBaseScala {
 
     it("return empty weights array when no neighbors") {
       val actualDf = Weighting.addDistanceBandColumn(getData(), .9)
+      hasOptimizationTurnedOn(actualDf)
 
       assert(actualDf.count() == 11)
       assert(actualDf.filter(f.size(f.col("weights")) > 0).count() == 0)
@@ -113,10 +119,12 @@ class WeightingTest extends TestBaseScala {
           f.col("id"),
           f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids"))
 
+      hasOptimizationTurnedOn(actualDfWithZeroDistanceNeighbors)
+
       assertDataFramesEqual(
         actualDfWithZeroDistanceNeighbors,
         sparkSession.createDataFrame(
-          Seq(Neighbors(0, Seq(1, 2)), Neighbors(1, Seq(0, 2)), Neighbors(2, Seq(0, 1)))))
+          Seq(Neighbors(0, Seq(2, 1)), Neighbors(1, Seq(2, 0)), Neighbors(2, Seq(1, 0)))))
 
       val actualDfWithoutZeroDistanceNeighbors = Weighting
         .addDistanceBandColumn(getDupedData(), 1.1)
@@ -127,15 +135,18 @@ class WeightingTest extends TestBaseScala {
       assertDataFramesEqual(
         actualDfWithoutZeroDistanceNeighbors,
         sparkSession.createDataFrame(
-          Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(0, 1)))))
+          Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(1, 0)))))
 
     }
 
     it("adds binary weights") {
-
       val result = Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "geometry")
       val weights = result.select("weights").collect().map(_.getSeq[Row](0))
+      hasOptimizationTurnedOn(result)
+
       assert(weights.forall(_.forall(_.getAs[Double]("value") == 1.0)))
+
+      hasOptimizationTurnedOn(result)
     }
 
     it("adds non-binary weights when binary is false") {
@@ -148,6 +159,8 @@ class WeightingTest extends TestBaseScala {
         geometry = "geometry")
       val weights = result.select("weights").collect().map(_.getSeq[Row](0))
       assert(weights.exists(_.exists(_.getAs[Double]("value") != 1.0)))
+
+      hasOptimizationTurnedOn(result)
     }
 
     it("throws IllegalArgumentException when threshold is negative") {
@@ -175,4 +188,27 @@ class WeightingTest extends TestBaseScala {
       }
     }
   }
+
+  private def hasOptimizationTurnedOn(result: DataFrame) = {
+    val sparkPlan = captureStdOut(result.explain())
+
+    val distanceJoinOptimization = "DistanceJoin"
+
+    val occurrences =
+      sparkPlan.sliding(distanceJoinOptimization.length).count(_ == distanceJoinOptimization)
+
+    assert(occurrences == 1)
+  }
+
+  def captureStdOut(block: => Unit): String = {
+    val stream = new ByteArrayOutputStream()
+    val ps = new PrintStream(stream)
+
+    Console.withOut(ps) {
+      block
+    }
+
+    ps.flush()
+    stream.toString("UTF-8")
+  }
 }

Reply via email to