This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 9497a075c9 [SEDONA-688] Verify KNN parameter K must be equal or larger than 1 (#1739)
9497a075c9 is described below

commit 9497a075c9c8517ddc2223c0deb92c43c5bfbdde
Author: Feng Zhang <[email protected]>
AuthorDate: Fri Jan 3 18:39:21 2025 -0800

    [SEDONA-688] Verify KNN parameter K must be equal or larger than 1 (#1739)
---
 .../strategy/join/BroadcastObjectSideKNNJoinExec.scala   |  2 +-
 .../strategy/join/BroadcastQuerySideKNNJoinExec.scala    |  2 +-
 .../sql/sedona_sql/strategy/join/JoinQueryDetector.scala |  8 ++++++++
 .../spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala |  2 +-
 .../test/scala/org/apache/sedona/sql/KnnJoinSuite.scala  | 16 ++++++++++++++++
 5 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
index 1b21c79e7c..c5777be3c1 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
@@ -120,7 +120,7 @@ case class BroadcastObjectSideKNNJoinExec(
       sedonaConf: SedonaConf): Unit = {
     require(numPartitions > 0, "The number of partitions must be greater than 
0.")
     val kValue: Int = this.k.eval().asInstanceOf[Int]
-    require(kValue > 0, "The number of neighbors must be greater than 0.")
+    require(kValue >= 1, "The number of neighbors (k) must be equal or greater 
than 1.")
     objectsShapes.setNeighborSampleNumber(kValue)
     broadcastJoin = true
   }
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 812bc6e6d6..001c0a1ca3 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -127,7 +127,7 @@ case class BroadcastQuerySideKNNJoinExec(
       sedonaConf: SedonaConf): Unit = {
     require(numPartitions > 0, "The number of partitions must be greater than 
0.")
     val kValue: Int = this.k.eval().asInstanceOf[Int]
-    require(kValue > 0, "The number of neighbors must be greater than 0.")
+    require(kValue >= 1, "The number of neighbors (k) must be equal or greater 
than 1.")
     objectsShapes.setNeighborSampleNumber(kValue)
 
     val joinPartitions: Integer = numPartitions
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 825855b88c..da9bd5359b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -582,6 +582,10 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
       return Nil
     }
 
+    // validate the k value
+    val kValue: Int = distance.eval().asInstanceOf[Int]
+    require(kValue >= 1, "The number of neighbors (k) must be equal or greater 
than 1.")
+
     val leftShape = children.head
     val rightShape = children.tail.head
 
@@ -711,6 +715,10 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
 
     if (spatialPredicate == SpatialPredicate.KNN) {
       {
+        // validate the k value for KNN join
+        val kValue: Int = distance.get.eval().asInstanceOf[Int]
+        require(kValue >= 1, "The number of neighbors (k) must be equal or 
greater than 1.")
+
         val leftShape = children.head
         val rightShape = children.tail.head
 
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
index 2b9bbfb50b..fdc53d13ce 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
@@ -162,7 +162,7 @@ case class KNNJoinExec(
       sedonaConf: SedonaConf): Unit = {
     require(numPartitions > 0, "The number of partitions must be greater than 
0.")
     val kValue: Int = this.k.eval().asInstanceOf[Int]
-    require(kValue > 0, "The number of neighbors must be greater than 0.")
+    require(kValue >= 1, "The number of neighbors (k) must be equal or greater 
than 1.")
     objectsShapes.setNeighborSampleNumber(kValue)
 
     exactSpatialPartitioning(objectsShapes, queryShapes, numPartitions)
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
index 1d6119d02d..f3b07c2501 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
@@ -209,6 +209,22 @@ class KnnJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
         "[1,3][1,6][1,13][1,16][2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]")
     }
 
+    it("KNN Join should verify the correct parameter k is passed to the join 
function") {
+      val df = sparkSession
+        .range(0, 1)
+        .toDF("id")
+        .withColumn("geom", expr("ST_Point(id, id)"))
+        .repartition(1)
+      df.createOrReplaceTempView("df1")
+      val exception = intercept[IllegalArgumentException] {
+        sparkSession
+          .sql(s"SELECT A.ID, B.ID FROM df1 A JOIN df1 B ON ST_KNN(A.GEOM, 
B.GEOM, 0, false)")
+          .collect()
+      }
+      exception.getMessage should include(
+        "The number of neighbors (k) must be equal or greater than 1.")
+    }
+
     it("KNN Join with exact algorithms with additional join conditions on id") 
{
       val df = sparkSession.sql(
         s"SELECT QUERIES.ID, OBJECTS.ID FROM QUERIES JOIN OBJECTS ON 
ST_KNN(QUERIES.GEOM, OBJECTS.GEOM, 4, false) AND QUERIES.ID > 1")

Reply via email to