This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch fix-performance-issue-with-weighting
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit f012d4200e9706e0a64e0cf1a35ad07c4a6cc5e9 Author: pawelkocinski <[email protected]> AuthorDate: Tue Nov 11 14:22:19 2025 +0100 SEDONA-748 Fix issue with no optimization for weighting function. --- .../scala/org/apache/sedona/stats/Weighting.scala | 50 ++++++++++++++-------- .../org/apache/sedona/stats/WeightingTest.scala | 42 ++++++++++++++++-- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala index d404f2c2db..aaad1b007b 100644 --- a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala +++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala @@ -109,30 +109,42 @@ object Weighting { val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256)) - formattedDataFrame + val spatiallyJoined = formattedDataFrame .alias("l") .join( formattedDataFrame.alias("r"), - joinCondition && col(s"l.$ID_COLUMN") =!= col( - s"r.$ID_COLUMN" - ), // we will add self back later if self.includeSelf + joinCondition && col(s"l.$ID_COLUMN") =!= col(s"r.$ID_COLUMN"), + "inner") + .select(struct("l.*").alias("left"), struct("r.*").alias("right")) + + val mapped = formattedDataFrame + .alias("f") + .join( + spatiallyJoined.alias("s"), + col(s"s.left.$ID_COLUMN") === col(s"f.$ID_COLUMN"), "left") .select( - col(s"l.$ID_COLUMN"), - struct("l.*").alias("left_contents"), - struct( - ( - savedAttributesWithGeom match { - case null => struct(col("r.*")).dropFields(ID_COLUMN) - case _ => - struct(savedAttributesWithGeom.map(c => col(s"r.$c")): _*) - } - ).alias("neighbor"), - if (!binary) - pow(distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")), alpha) - .alias("value") - else lit(1.0).alias("value")).alias("weight")) - .groupBy(s"l.$ID_COLUMN") + col(ID_COLUMN), + struct("f.*").alias("left_contents"), + when(col(ID_COLUMN).isNull, lit(null)) + .otherwise(struct( + ( 
+ savedAttributesWithGeom match { + case null => struct(col("s.right.*")).dropFields(ID_COLUMN) + case _ => + struct(savedAttributesWithGeom.map(c => col(s"s.right.$c")): _*) + } + ).alias("neighbor"), + if (!binary) + pow( + distanceFunction(col(s"s.left.$geometryColumn"), col(s"s.right.$geometryColumn")), + alpha) + .alias("value") + else lit(1.0).alias("value"))) + .alias("weight")) + + mapped + .groupBy(ID_COLUMN) .agg( first("left_contents").alias("left_contents"), concat( diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala index a7a8865dda..4fcdaa654d 100644 --- a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala @@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint import org.apache.spark.sql.{DataFrame, Row, functions => f} +import java.io.{ByteArrayOutputStream, PrintStream} + class WeightingTest extends TestBaseScala { case class Neighbors(id: Int, neighbor: Seq[Int]) @@ -78,6 +80,9 @@ class WeightingTest extends TestBaseScala { f.col("id"), f.array_sort( f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids"))) + + hasOptimizationTurnedOn(actualDf) + val expectedDf = sparkSession.createDataFrame( Seq( Neighbors(0, Seq(1, 3, 5, 7)), @@ -97,6 +102,7 @@ class WeightingTest extends TestBaseScala { it("return empty weights array when no neighbors") { val actualDf = Weighting.addDistanceBandColumn(getData(), .9) + hasOptimizationTurnedOn(actualDf) assert(actualDf.count() == 11) assert(actualDf.filter(f.size(f.col("weights")) > 0).count() == 0) @@ -113,10 +119,12 @@ class WeightingTest extends TestBaseScala { f.col("id"), f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")) + 
hasOptimizationTurnedOn(actualDfWithZeroDistanceNeighbors) + assertDataFramesEqual( actualDfWithZeroDistanceNeighbors, sparkSession.createDataFrame( - Seq(Neighbors(0, Seq(1, 2)), Neighbors(1, Seq(0, 2)), Neighbors(2, Seq(0, 1))))) + Seq(Neighbors(0, Seq(2, 1)), Neighbors(1, Seq(2, 0)), Neighbors(2, Seq(1, 0))))) val actualDfWithoutZeroDistanceNeighbors = Weighting .addDistanceBandColumn(getDupedData(), 1.1) @@ -127,15 +135,18 @@ class WeightingTest extends TestBaseScala { assertDataFramesEqual( actualDfWithoutZeroDistanceNeighbors, sparkSession.createDataFrame( - Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(0, 1))))) + Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(1, 0))))) } it("adds binary weights") { - val result = Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "geometry") val weights = result.select("weights").collect().map(_.getSeq[Row](0)) + hasOptimizationTurnedOn(result) + assert(weights.forall(_.forall(_.getAs[Double]("value") == 1.0))) + + hasOptimizationTurnedOn(result) } it("adds non-binary weights when binary is false") { @@ -148,6 +159,8 @@ class WeightingTest extends TestBaseScala { geometry = "geometry") val weights = result.select("weights").collect().map(_.getSeq[Row](0)) assert(weights.exists(_.exists(_.getAs[Double]("value") != 1.0))) + + hasOptimizationTurnedOn(result) } it("throws IllegalArgumentException when threshold is negative") { @@ -175,4 +188,27 @@ class WeightingTest extends TestBaseScala { } } } + + private def hasOptimizationTurnedOn(result: DataFrame) = { + val sparkPlan = captureStdOut(result.explain()) + + val distanceJoinOptimization = "DistanceJoin" + + val occurrences = + sparkPlan.sliding(distanceJoinOptimization.length).count(_ == distanceJoinOptimization) + + assert(occurrences == 1) + } + + def captureStdOut(block: => Unit): String = { + val stream = new ByteArrayOutputStream() + val ps = new PrintStream(stream) + + Console.withOut(ps) { + block + } + + ps.flush() + 
stream.toString("UTF-8") + } }
