This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fd3b86518 [SEDONA-690] Set default metric to use Haversine for KNN join and code refactoring (#1909)
1fd3b86518 is described below

commit 1fd3b86518d97c56a72f4f31a4dcfb67a6d55496
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Apr 10 09:52:58 2025 -0700

    [SEDONA-690] Set default metric to use Haversine for KNN join and code refactoring (#1909)
    
    * [SEDONA-690] Set default metric to use Haversine for KNN join and some code refactoring
    
    * fix unit tests
    
    * clean up join params
---
 .../joinJudgement/InMemoryKNNJoinIterator.java     | 155 ++++++++++++++++
 .../core/joinJudgement/KnnJoinIndexJudgement.java  | 200 +++++++--------------
 .../sedona/core/spatialOperator/JoinQuery.java     | 169 ++++++++---------
 .../join/BroadcastObjectSideKNNJoinExec.scala      |   4 +-
 .../join/BroadcastQuerySideKNNJoinExec.scala       |   4 +-
 .../sql/sedona_sql/strategy/join/KNNJoinExec.scala |   4 +-
 .../scala/org/apache/sedona/sql/KnnJoinSuite.scala |  28 +++
 7 files changed, 329 insertions(+), 235 deletions(-)

diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
new file mode 100644
index 0000000000..54ba42485e
--- /dev/null
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.core.joinJudgement;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.NoSuchElementException;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sedona.core.enums.DistanceMetric;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
+import org.apache.spark.util.LongAccumulator;
+import org.locationtech.jts.geom.Envelope;
+import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.index.strtree.ItemDistance;
+import org.locationtech.jts.index.strtree.STRtree;
+
+public class InMemoryKNNJoinIterator<T extends Geometry, U extends Geometry>
+    implements Iterator<Pair<T, U>> {
+  private final Iterator<T> querySideIterator;
+  private final STRtree strTree;
+
+  private final int k;
+  private final DistanceMetric distanceMetric;
+  private final boolean includeTies;
+  private final ItemDistance itemDistance;
+
+  private final LongAccumulator streamCount;
+  private final LongAccumulator resultCount;
+
+  private final List<Pair<T, U>> currentResults = new ArrayList<>();
+  private int currentResultIndex = 0;
+
+  public InMemoryKNNJoinIterator(
+      Iterator<T> querySideIterator,
+      STRtree strTree,
+      int k,
+      DistanceMetric distanceMetric,
+      boolean includeTies,
+      LongAccumulator streamCount,
+      LongAccumulator resultCount) {
+    this.querySideIterator = querySideIterator;
+    this.strTree = strTree;
+
+    this.k = k;
+    this.distanceMetric = distanceMetric;
+    this.includeTies = includeTies;
+    this.itemDistance = KnnJoinIndexJudgement.getItemDistance(distanceMetric);
+
+    this.streamCount = streamCount;
+    this.resultCount = resultCount;
+  }
+
+  @Override
+  public boolean hasNext() {
+    if (currentResultIndex < currentResults.size()) {
+      return true;
+    }
+
+    currentResultIndex = 0;
+    currentResults.clear();
+    while (querySideIterator.hasNext()) {
+      populateNextBatch();
+      if (!currentResults.isEmpty()) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  @Override
+  public Pair<T, U> next() {
+    if (!hasNext()) {
+      throw new NoSuchElementException();
+    }
+
+    return currentResults.get(currentResultIndex++);
+  }
+
+  private void populateNextBatch() {
+    T queryItem = querySideIterator.next();
+    Geometry queryGeom;
+    if (queryItem instanceof UniqueGeometry) {
+      queryGeom = (Geometry) ((UniqueGeometry<?>) 
queryItem).getOriginalGeometry();
+    } else {
+      queryGeom = queryItem;
+    }
+    streamCount.add(1);
+
+    Object[] localK =
+        strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom, 
itemDistance, k);
+    if (includeTies) {
+      localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
+    }
+
+    for (Object obj : localK) {
+      U candidate = (U) obj;
+      Pair<T, U> pair = Pair.of(queryItem, candidate);
+      currentResults.add(pair);
+      resultCount.add(1);
+    }
+  }
+
+  private Object[] getUpdatedLocalKWithTies(
+      Geometry streamShape, Object[] localK, STRtree strTree) {
+    Envelope searchEnvelope = streamShape.getEnvelopeInternal();
+    // get the maximum distance from the k nearest neighbors
+    double maxDistance = 0.0;
+    LinkedHashSet<U> uniqueCandidates = new LinkedHashSet<>();
+    for (Object obj : localK) {
+      U candidate = (U) obj;
+      uniqueCandidates.add(candidate);
+      double distance = streamShape.distance(candidate);
+      if (distance > maxDistance) {
+        maxDistance = distance;
+      }
+    }
+    searchEnvelope.expandBy(maxDistance);
+    List<U> candidates = strTree.query(searchEnvelope);
+    if (!candidates.isEmpty()) {
+      // update localK with all candidates that are within the maxDistance
+      List<Object> tiedResults = new ArrayList<>();
+      // add all localK
+      Collections.addAll(tiedResults, localK);
+
+      for (U candidate : candidates) {
+        double distance = streamShape.distance(candidate);
+        if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
+          tiedResults.add(candidate);
+        }
+      }
+      localK = tiedResults.toArray();
+    }
+    return localK;
+  }
+}
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
index f5375009ed..0dda586986 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
@@ -19,19 +19,18 @@
 package org.apache.sedona.core.joinJudgement;
 
 import java.io.Serializable;
-import java.util.*;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sedona.core.enums.DistanceMetric;
 import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
 import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
 import org.apache.sedona.core.knnJudgement.SpheroidDistance;
-import org.apache.sedona.core.wrapper.UniqueGeometry;
 import org.apache.spark.api.java.function.FlatMapFunction2;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
-import org.locationtech.jts.geom.Envelope;
 import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.index.SpatialIndex;
 import org.locationtech.jts.index.strtree.GeometryItemDistance;
 import org.locationtech.jts.index.strtree.ItemDistance;
 import org.locationtech.jts.index.strtree.STRtree;
@@ -45,19 +44,17 @@ import org.locationtech.jts.index.strtree.STRtree;
  */
 public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
     extends JudgementBase<T, U>
-    implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, 
T>>, Serializable {
+    implements FlatMapFunction2<Iterator<T>, Iterator<U>, Pair<T, U>>, 
Serializable {
   private final int k;
-  private final Double searchRadius;
   private final DistanceMetric distanceMetric;
   private final boolean includeTies;
-  private final Broadcast<List> broadcastQueryObjects;
+  private final Broadcast<List<T>> broadcastQueryObjects;
   private final Broadcast<STRtree> broadcastObjectsTreeIndex;
 
   /**
    * Constructor for the KnnJoinIndexJudgement class.
    *
    * @param k the number of nearest neighbors to find
-   * @param searchRadius
    * @param distanceMetric the distance metric to use
    * @param broadcastQueryObjects the broadcast geometries on queries
    * @param broadcastObjectsTreeIndex the broadcast spatial index on objects
@@ -68,10 +65,9 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
    */
   public KnnJoinIndexJudgement(
       int k,
-      Double searchRadius,
       DistanceMetric distanceMetric,
       boolean includeTies,
-      Broadcast<List> broadcastQueryObjects,
+      Broadcast<List<T>> broadcastQueryObjects,
       Broadcast<STRtree> broadcastObjectsTreeIndex,
       LongAccumulator buildCount,
       LongAccumulator streamCount,
@@ -79,7 +75,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
       LongAccumulator candidateCount) {
     super(null, buildCount, streamCount, resultCount, candidateCount);
     this.k = k;
-    this.searchRadius = searchRadius;
     this.distanceMetric = distanceMetric;
     this.includeTies = includeTies;
     this.broadcastQueryObjects = broadcastQueryObjects;
@@ -91,15 +86,15 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
    * and uses the spatial index to find the k nearest neighbors for each 
geometry. The method
    * returns an iterator over the join results.
    *
-   * @param streamShapes iterator over the geometries in the stream side
-   * @param treeIndexes iterator over the spatial indexes
+   * @param queryShapes iterator over the geometries in the query side
+   * @param objectShapes iterator over the geometries in the object side
    * @return an iterator over the join results
    * @throws Exception if the spatial index is not of type STRtree
    */
   @Override
-  public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, 
Iterator<SpatialIndex> treeIndexes)
+  public Iterator<Pair<T, U>> call(Iterator<T> queryShapes, Iterator<U> 
objectShapes)
       throws Exception {
-    if (!treeIndexes.hasNext() || (streamShapes != null && 
!streamShapes.hasNext())) {
+    if (!objectShapes.hasNext() || (queryShapes != null && 
!queryShapes.hasNext())) {
       buildCount.add(0);
       streamCount.add(0);
       resultCount.add(0);
@@ -107,91 +102,64 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
       return Collections.emptyIterator();
     }
 
-    STRtree strTree;
-    if (broadcastObjectsTreeIndex != null) {
-      // get the broadcast spatial index on objects side if available
-      strTree = broadcastObjectsTreeIndex.getValue();
-    } else {
-      // get the spatial index from the iterator
-      SpatialIndex treeIndex = treeIndexes.next();
-      if (!(treeIndex instanceof STRtree)) {
-        throw new Exception(
-            "[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN 
search.");
-      }
-      strTree = (STRtree) treeIndex;
-    }
-
-    // TODO: For future improvement, instead of using a list to store the 
results,
-    // we can use lazy evaluation to avoid storing all the results in memory.
-    List<Pair<U, T>> result = new ArrayList<>();
-
-    List queryItems;
-    if (broadcastQueryObjects != null) {
-      // get the broadcast spatial index on queries side if available
-      queryItems = broadcastQueryObjects.getValue();
-      for (Object item : queryItems) {
-        T queryGeom;
-        if (item instanceof UniqueGeometry) {
-          queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry();
-        } else {
-          queryGeom = (T) item;
-        }
-        streamCount.add(1);
-
-        Object[] localK =
-            strTree.nearestNeighbour(
-                queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), 
k);
-        if (includeTies) {
-          localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
-        }
-        if (searchRadius != null) {
-          localK = getInSearchRadius(localK, queryGeom);
-        }
+    STRtree strTree = buildSTRtree(objectShapes);
+    return new InMemoryKNNJoinIterator<>(
+        queryShapes, strTree, k, distanceMetric, includeTies, streamCount, 
resultCount);
+  }
 
-        for (Object obj : localK) {
-          T candidate = (T) obj;
-          Pair<U, T> pair = Pair.of((U) item, candidate);
-          result.add(pair);
-          resultCount.add(1);
-        }
-      }
-      return result.iterator();
-    } else {
-      while (streamShapes.hasNext()) {
-        T streamShape = streamShapes.next();
-        streamCount.add(1);
+  /**
+   * This method performs the KNN join operation using the broadcast spatial 
index built using all
+   * geometries in the object side.
+   *
+   * @param queryShapes iterator over the geometries in the query side
+   * @return an iterator over the join results
+   */
+  public Iterator<Pair<T, U>> callUsingBroadcastObjectIndex(Iterator<T> 
queryShapes) {
+    if (!queryShapes.hasNext()) {
+      buildCount.add(0);
+      streamCount.add(0);
+      resultCount.add(0);
+      candidateCount.add(0);
+      return Collections.emptyIterator();
+    }
 
-        Object[] localK =
-            strTree.nearestNeighbour(
-                streamShape.getEnvelopeInternal(), streamShape, 
getItemDistance(), k);
-        if (includeTies) {
-          localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
-        }
-        if (searchRadius != null) {
-          localK = getInSearchRadius(localK, streamShape);
-        }
+    // There's no need to use external spatial index, since the object side is 
small enough to be
+    // broadcasted, the STRtree built from the broadcasted object should be 
able to fit into memory.
+    STRtree strTree = broadcastObjectsTreeIndex.getValue();
+    return new InMemoryKNNJoinIterator<>(
+        queryShapes, strTree, k, distanceMetric, includeTies, streamCount, 
resultCount);
+  }
 
-        for (Object obj : localK) {
-          T candidate = (T) obj;
-          Pair<U, T> pair = Pair.of((U) streamShape, candidate);
-          result.add(pair);
-          resultCount.add(1);
-        }
-      }
-      return result.iterator();
+  /**
+   * This method performs the KNN join operation using the broadcast query 
geometries.
+   *
+   * @param objectShapes iterator over the geometries in the object side
+   * @return an iterator over the join results
+   */
+  public Iterator<Pair<T, U>> callUsingBroadcastQueryList(Iterator<U> 
objectShapes) {
+    if (!objectShapes.hasNext()) {
+      buildCount.add(0);
+      streamCount.add(0);
+      resultCount.add(0);
+      candidateCount.add(0);
+      return Collections.emptyIterator();
     }
+
+    List<T> queryItems = broadcastQueryObjects.getValue();
+    STRtree strTree = buildSTRtree(objectShapes);
+    return new InMemoryKNNJoinIterator<>(
+        queryItems.iterator(), strTree, k, distanceMetric, includeTies, 
streamCount, resultCount);
   }
 
-  private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
-    localK =
-        Arrays.stream(localK)
-            .filter(
-                candidate -> {
-                  Geometry candidateGeom = (Geometry) candidate;
-                  return distanceByMetric(queryGeom, candidateGeom, 
distanceMetric) <= searchRadius;
-                })
-            .toArray();
-    return localK;
+  private STRtree buildSTRtree(Iterator<U> objectShapes) {
+    STRtree strTree = new STRtree();
+    while (objectShapes.hasNext()) {
+      U spatialObject = objectShapes.next();
+      strTree.insert(spatialObject.getEnvelopeInternal(), spatialObject);
+      buildCount.add(1);
+    }
+    strTree.build();
+    return strTree;
   }
 
   /**
@@ -219,12 +187,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
     }
   }
 
-  private ItemDistance getItemDistance() {
-    ItemDistance itemDistance;
-    itemDistance = getItemDistanceByMetric(distanceMetric);
-    return itemDistance;
-  }
-
   /**
    * This method returns the ItemDistance object based on the specified 
distance metric.
    *
@@ -250,38 +212,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
     return itemDistance;
   }
 
-  private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, 
STRtree strTree) {
-    Envelope searchEnvelope = streamShape.getEnvelopeInternal();
-    // get the maximum distance from the k nearest neighbors
-    double maxDistance = 0.0;
-    LinkedHashSet<T> uniqueCandidates = new LinkedHashSet<>();
-    for (Object obj : localK) {
-      T candidate = (T) obj;
-      uniqueCandidates.add(candidate);
-      double distance = streamShape.distance(candidate);
-      if (distance > maxDistance) {
-        maxDistance = distance;
-      }
-    }
-    searchEnvelope.expandBy(maxDistance);
-    List<T> candidates = strTree.query(searchEnvelope);
-    if (!candidates.isEmpty()) {
-      // update localK with all candidates that are within the maxDistance
-      List<Object> tiedResults = new ArrayList<>();
-      // add all localK
-      Collections.addAll(tiedResults, localK);
-
-      for (T candidate : candidates) {
-        double distance = streamShape.distance(candidate);
-        if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
-          tiedResults.add(candidate);
-        }
-      }
-      localK = tiedResults.toArray();
-    }
-    return localK;
-  }
-
   public static <U extends Geometry, T extends Geometry> double distance(
       U key, T value, DistanceMetric distanceMetric) {
     switch (distanceMetric) {
@@ -295,4 +225,10 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
         return new EuclideanItemDistance().distance(key, value);
     }
   }
+
+  public static ItemDistance getItemDistance(DistanceMetric distanceMetric) {
+    ItemDistance itemDistance;
+    itemDistance = getItemDistanceByMetric(distanceMetric);
+    return itemDistance;
+  }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
 
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
index a5665726e0..7b55dd0763 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
@@ -37,13 +37,11 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
 import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.index.SpatialIndex;
 import org.locationtech.jts.index.strtree.STRtree;
 import scala.Tuple2;
 
@@ -401,7 +399,7 @@ public class JoinQuery {
       DistanceMetric distanceMetric)
       throws Exception {
     final JoinParams joinParams =
-        new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric, 
null);
+        new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric);
 
     final JavaPairRDD<U, T> joinResults = knnJoin(queryRDD, objectRDD, 
joinParams, false, false);
     return collectGeometriesByKey(joinResults);
@@ -785,7 +783,7 @@ public class JoinQuery {
     LongAccumulator candidateCount = Metrics.createMetric(sparkContext, 
"candidateCount");
 
     final Broadcast<STRtree> broadcastObjectsTreeIndex;
-    final Broadcast<List> broadcastQueryObjects;
+    final Broadcast<List<UniqueGeometry<U>>> broadcastQueryObjects;
     if (broadcastJoin && objectRDD.indexedRawRDD != null && 
objectRDD.indexedRDD == null) {
       // If broadcastJoin is true and rawIndex is created on object side
       // we will broadcast queryRDD to objectRDD
@@ -816,10 +814,9 @@ public class JoinQuery {
     final JavaRDD<Pair<U, T>> joinResult;
     if (broadcastObjectsTreeIndex == null && broadcastQueryObjects == null) {
       // no broadcast join
-      final KnnJoinIndexJudgement judgement =
-          new KnnJoinIndexJudgement(
+      final KnnJoinIndexJudgement<U, T> judgement =
+          new KnnJoinIndexJudgement<>(
               joinParams.k,
-              joinParams.searchRadius,
               joinParams.distanceMetric,
               includeTies,
               null,
@@ -828,13 +825,13 @@ public class JoinQuery {
               streamCount,
               resultCount,
               candidateCount);
-      joinResult = 
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement);
+      joinResult =
+          
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.spatialPartitionedRDD, 
judgement);
     } else if (broadcastObjectsTreeIndex != null) {
       // broadcast join with objectRDD as broadcast side
-      final KnnJoinIndexJudgement judgement =
-          new KnnJoinIndexJudgement(
+      final KnnJoinIndexJudgement<U, T> judgement =
+          new KnnJoinIndexJudgement<>(
               joinParams.k,
-              joinParams.searchRadius,
               joinParams.distanceMetric,
               includeTies,
               null,
@@ -844,13 +841,12 @@ public class JoinQuery {
               resultCount,
               candidateCount);
       // won't need inputs from the shapes in the objectRDD
-      joinResult = 
queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement);
-    } else if (broadcastQueryObjects != null) {
+      joinResult = 
queryRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastObjectIndex);
+    } else {
       // broadcast join with queryRDD as broadcast side
-      final KnnJoinIndexJudgement judgement =
-          new KnnJoinIndexJudgement(
+      final KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement =
+          new KnnJoinIndexJudgement<>(
               joinParams.k,
-              joinParams.searchRadius,
               joinParams.distanceMetric,
               includeTies,
               broadcastQueryObjects,
@@ -860,8 +856,6 @@ public class JoinQuery {
               resultCount,
               candidateCount);
       joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, 
includeTies);
-    } else {
-      throw new IllegalArgumentException("No index found on the input RDDs.");
     }
 
     return joinResult.mapToPair(
@@ -891,22 +885,11 @@ public class JoinQuery {
       JavaRDD<Pair<U, T>> querySideBroadcastKNNJoin(
           SpatialRDD<T> objectRDD,
           JoinParams joinParams,
-          KnnJoinIndexJudgement judgement,
+          KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement,
           boolean includeTies) {
     final JavaRDD<Pair<U, T>> joinResult;
-    JavaRDD<Pair<U, T>> joinResultMapped =
-        objectRDD.indexedRawRDD.mapPartitions(
-            iterator -> {
-              List<Pair<U, T>> results = new ArrayList<>();
-              if (iterator.hasNext()) {
-                SpatialIndex spatialIndex = iterator.next();
-                // the broadcast join won't need inputs from the query's shape 
stream
-                Iterator<Pair<U, T>> callResult =
-                    judgement.call(null, 
Collections.singletonList(spatialIndex).iterator());
-                callResult.forEachRemaining(results::add);
-              }
-              return results.iterator();
-            });
+    JavaRDD<Pair<UniqueGeometry<U>, T>> joinResultMapped =
+        
objectRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastQueryList);
     // this is to avoid serializable issues with the broadcast variable
     int k = joinParams.k;
     DistanceMetric distanceMetric = joinParams.distanceMetric;
@@ -915,72 +898,67 @@ public class JoinQuery {
     // (based on a grouping key and distance)
     joinResult =
         joinResultMapped
-            .groupBy(pair -> pair.getKey()) // Group by the first geometry
+            .groupBy(pair -> pair.getKey().getUniqueId())
             .flatMap(
-                (FlatMapFunction<Tuple2<U, Iterable<Pair<U, T>>>, Pair<U, T>>)
-                    pair -> {
-                      Iterable<Pair<U, T>> values = pair._2;
-
-                      // Extract and sort values by distance
-                      List<Pair<U, T>> sortedPairs = new ArrayList<>();
-                      for (Pair<U, T> p : values) {
-                        Pair<U, T> newPair =
-                            Pair.of(
-                                (U) ((UniqueGeometry<?>) 
p.getKey()).getOriginalGeometry(),
-                                p.getValue());
-                        sortedPairs.add(newPair);
-                      }
-
-                      // Sort pairs based on the distance function between the 
two geometries
-                      sortedPairs.sort(
-                          (p1, p2) -> {
-                            double distance1 =
-                                KnnJoinIndexJudgement.distance(
-                                    p1.getKey(), p1.getValue(), 
distanceMetric);
-                            double distance2 =
-                                KnnJoinIndexJudgement.distance(
-                                    p2.getKey(), p2.getValue(), 
distanceMetric);
-                            return Double.compare(
-                                distance1, distance2); // Sort ascending by 
distance
-                          });
-
-                      if (includeTies) {
-                        // Keep the top k pairs, including ties
-                        List<Pair<U, T>> topPairs = new ArrayList<>();
-                        double kthDistance = -1;
-                        for (int i = 0; i < sortedPairs.size(); i++) {
-                          if (i < k) {
-                            topPairs.add(sortedPairs.get(i));
-                            if (i == k - 1) {
-                              kthDistance =
-                                  KnnJoinIndexJudgement.distance(
-                                      sortedPairs.get(i).getKey(),
-                                      sortedPairs.get(i).getValue(),
-                                      distanceMetric);
-                            }
-                          } else {
-                            double currentDistance =
-                                KnnJoinIndexJudgement.distance(
-                                    sortedPairs.get(i).getKey(),
-                                    sortedPairs.get(i).getValue(),
-                                    distanceMetric);
-                            if (currentDistance == kthDistance) {
-                              topPairs.add(sortedPairs.get(i));
-                            } else {
-                              break;
-                            }
-                          }
+                pair -> {
+                  Iterable<Pair<UniqueGeometry<U>, T>> values = pair._2;
+
+                  // Extract and sort values by distance
+                  List<Pair<U, T>> sortedPairs = new ArrayList<>();
+                  for (Pair<UniqueGeometry<U>, T> p : values) {
+                    Pair<U, T> newPair = 
Pair.of(p.getKey().getOriginalGeometry(), p.getValue());
+                    sortedPairs.add(newPair);
+                  }
+
+                  // Sort pairs based on the distance function between the two 
geometries
+                  sortedPairs.sort(
+                      (p1, p2) -> {
+                        double distance1 =
+                            KnnJoinIndexJudgement.distance(
+                                p1.getKey(), p1.getValue(), distanceMetric);
+                        double distance2 =
+                            KnnJoinIndexJudgement.distance(
+                                p2.getKey(), p2.getValue(), distanceMetric);
+                        return Double.compare(distance1, distance2); // Sort 
ascending by distance
+                      });
+
+                  if (includeTies) {
+                    // Keep the top k pairs, including ties
+                    List<Pair<U, T>> topPairs = new ArrayList<>();
+                    double kthDistance = -1;
+                    for (int i = 0; i < sortedPairs.size(); i++) {
+                      if (i < k) {
+                        topPairs.add(sortedPairs.get(i));
+                        if (i == k - 1) {
+                          kthDistance =
+                              KnnJoinIndexJudgement.distance(
+                                  sortedPairs.get(i).getKey(),
+                                  sortedPairs.get(i).getValue(),
+                                  distanceMetric);
                         }
-                        return topPairs.iterator();
                       } else {
-                        // Keep the top k pairs without ties
-                        List<Pair<U, T>> topPairs = new ArrayList<>();
-                        for (int i = 0; i < Math.min(k, sortedPairs.size()); 
i++) {
+                        double currentDistance =
+                            KnnJoinIndexJudgement.distance(
+                                sortedPairs.get(i).getKey(),
+                                sortedPairs.get(i).getValue(),
+                                distanceMetric);
+                        if (currentDistance == kthDistance) {
                           topPairs.add(sortedPairs.get(i));
+                        } else {
+                          break;
                         }
-                        return topPairs.iterator();
                       }
-                    });
+                    }
+                    return topPairs.iterator();
+                  } else {
+                    // Keep the top k pairs without ties
+                    List<Pair<U, T>> topPairs = new ArrayList<>();
+                    for (int i = 0; i < Math.min(k, sortedPairs.size()); i++) {
+                      topPairs.add(sortedPairs.get(i));
+                    }
+                    return topPairs.iterator();
+                  }
+                });
 
     return joinResult;
   }
@@ -994,14 +972,13 @@ public class JoinQuery {
     // KNN specific parameters
     public final int k;
     public final DistanceMetric distanceMetric;
-    public final Double searchRadius;
 
     public JoinParams(
         boolean useIndex,
         SpatialPredicate spatialPredicate,
         IndexType polygonIndexType,
         JoinBuildSide joinBuildSide) {
-      this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1, 
null, null);
+      this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1, 
null);
     }
 
     public JoinParams(
@@ -1010,15 +987,13 @@ public class JoinQuery {
         IndexType polygonIndexType,
         JoinBuildSide joinBuildSide,
         int k,
-        DistanceMetric distanceMetric,
-        Double searchRadius) {
+        DistanceMetric distanceMetric) {
       this.useIndex = useIndex;
       this.spatialPredicate = spatialPredicate;
       this.indexType = polygonIndexType;
       this.joinBuildSide = joinBuildSide;
       this.k = k;
       this.distanceMetric = distanceMetric;
-      this.searchRadius = searchRadius;
     }
 
     public JoinParams(boolean useIndex, SpatialPredicate spatialPredicate) {
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
index c5777be3c1..f4bdae40d5 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
@@ -138,9 +138,9 @@ case class BroadcastObjectSideKNNJoinExec(
     // Number of neighbors to find
     val kValue: Int = this.k.eval().asInstanceOf[Int]
     // Metric to use in the join to calculate the distance, only Euclidean and 
Spheroid are supported
-    val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else 
DistanceMetric.EUCLIDEAN
+    val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else 
DistanceMetric.EUCLIDEAN
     val joinParams =
-      new JoinParams(true, null, IndexType.RTREE, null, kValue, 
distanceMetric, null)
+      new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
     joinParams
   }
 
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 9ce40c6d42..575ded9125 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -149,9 +149,9 @@ case class BroadcastQuerySideKNNJoinExec(
     // Number of neighbors to find
     val kValue: Int = this.k.eval().asInstanceOf[Int]
     // Metric to use in the join to calculate the distance, only Euclidean and 
Spheroid are supported
-    val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else 
DistanceMetric.EUCLIDEAN
+    val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else 
DistanceMetric.EUCLIDEAN
     val joinParams =
-      new JoinParams(true, null, IndexType.RTREE, null, kValue, 
distanceMetric, null)
+      new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
     joinParams
   }
 
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
index fdc53d13ce..a879447d40 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
@@ -209,9 +209,9 @@ case class KNNJoinExec(
     // Number of neighbors to find
     val kValue: Int = this.k.eval().asInstanceOf[Int]
     // Metric to use in the join to calculate the distance, only Euclidean and 
Spheroid are supported
-    val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else 
DistanceMetric.EUCLIDEAN
+    val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else 
DistanceMetric.EUCLIDEAN
     val joinParams =
-      new JoinParams(true, null, IndexType.RTREE, null, kValue, 
distanceMetric, null)
+      new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
     joinParams
   }
 }
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
index ab2c64898a..50696337f7 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
@@ -458,6 +458,34 @@ class KnnJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       df2.cache()
       df1.join(df2, expr("ST_KNN(geom1, geom2, 1)")).count() should be(0)
     }
+
+    it("KNN Join using spider data source") {
+      val dfRandomSquares = sparkSession.read
+        .format("spider")
+        .option("n", "10000")
+        .option("distribution", "parcel")
+        .option("dither", "0.5")
+        .option("splitRange", "0.5")
+        .load()
+
+      dfRandomSquares.createOrReplaceTempView("df_random_squares")
+
+      val dfRandomPoints = sparkSession.read
+        .format("spider")
+        .option("n", "1000")
+        .option("distribution", "uniform")
+        .load()
+
+      dfRandomPoints.createOrReplaceTempView("df_random_points")
+
+      // KNN join: for each square (query side, first arg) find its single nearest point (object side)
+      val knnJoined = sparkSession.sql("""SELECT sq.id AS square_id, pt.id AS point_id
+          |FROM df_random_squares sq
+          |JOIN df_random_points pt
+          |ON ST_KNN(sq.geometry, pt.geometry, 1, TRUE)""".stripMargin)
+
+      knnJoined.count() should be(dfRandomSquares.count())
+    }
   }
 
   private def withOptimizationMode(mode: String)(body: => Unit): Unit = {


Reply via email to