This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new af74a17cf1 [SEDONA-690] Optimize query side broadcast knn join (#1741)
af74a17cf1 is described below
commit af74a17cf1a274431542fd2ce74b86fb5cb2de52
Author: Feng Zhang <[email protected]>
AuthorDate: Sun Jan 5 06:49:01 2025 -0800
[SEDONA-690] Optimize query side broadcast knn join (#1741)
* [SEDONA-688] Verify that the KNN parameter k is greater than or equal to 1
* [SEDONA-690] Optimize query side broadcast knn join
* fix isGeography parameter
---
.../core/joinJudgement/KnnJoinIndexJudgement.java | 189 +++++++++++++++-----
.../core/knnJudgement/EuclideanItemDistance.java | 8 +
.../core/knnJudgement/HaversineItemDistance.java | 8 +
.../sedona/core/knnJudgement/SpheroidDistance.java | 8 +
.../sedona/core/spatialOperator/JoinQuery.java | 190 ++++++++++++++++++---
.../apache/sedona/core/wrapper/UniqueGeometry.java | 168 ++++++++++++++++++
.../join/BroadcastQuerySideKNNJoinExec.scala | 17 +-
.../strategy/join/JoinQueryDetector.scala | 43 ++---
8 files changed, 535 insertions(+), 96 deletions(-)
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
index 1c7fe7a0ae..f5375009ed 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
@@ -25,6 +25,7 @@ import org.apache.sedona.core.enums.DistanceMetric;
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
@@ -46,35 +47,43 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
extends JudgementBase<T, U>
implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U,
T>>, Serializable {
private final int k;
+ // Optional upper bound on the distance (as computed for distanceMetric) between a query and a
+ // returned neighbor; null disables the radius filter.
+ private final Double searchRadius;
private final DistanceMetric distanceMetric;
private final boolean includeTies;
- private final Broadcast<STRtree> broadcastedTreeIndex;
+ // Query geometries broadcast to every object partition; null unless the query side is broadcast.
+ private final Broadcast<List> broadcastQueryObjects;
+ // R-tree over the object geometries broadcast to every query partition; null unless the object
+ // side is broadcast.
+ private final Broadcast<STRtree> broadcastObjectsTreeIndex;
/**
* Constructor for the KnnJoinIndexJudgement class.
*
* @param k the number of nearest neighbors to find
+ * @param searchRadius optional upper bound on the distance between a query and a returned
+ *     neighbor; {@code null} disables the radius filter
* @param distanceMetric the distance metric to use
+ * @param includeTies whether to extend the k nearest neighbors with tied candidates
+ * @param broadcastQueryObjects the broadcast query geometries (elements may be wrapped in
+ *     {@link UniqueGeometry}); {@code null} when the query side is not broadcast
+ * @param broadcastObjectsTreeIndex the broadcast spatial index on objects; {@code null} when
+ *     the object side is not broadcast
* @param buildCount accumulator for the number of geometries processed from
the build side
* @param streamCount accumulator for the number of geometries processed
from the stream side
* @param resultCount accumulator for the number of join results
* @param candidateCount accumulator for the number of candidate matches
- * @param broadcastedTreeIndex the broadcasted spatial index
*/
public KnnJoinIndexJudgement(
int k,
+ Double searchRadius,
DistanceMetric distanceMetric,
boolean includeTies,
- Broadcast<STRtree> broadcastedTreeIndex,
+ Broadcast<List> broadcastQueryObjects,
+ Broadcast<STRtree> broadcastObjectsTreeIndex,
LongAccumulator buildCount,
LongAccumulator streamCount,
LongAccumulator resultCount,
LongAccumulator candidateCount) {
super(null, buildCount, streamCount, resultCount, candidateCount);
this.k = k;
+ this.searchRadius = searchRadius;
this.distanceMetric = distanceMetric;
this.includeTies = includeTies;
- this.broadcastedTreeIndex = broadcastedTreeIndex;
+ this.broadcastQueryObjects = broadcastQueryObjects;
+ this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex;
}
/**
@@ -90,7 +99,7 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
@Override
public Iterator<Pair<U, T>> call(Iterator<T> streamShapes,
Iterator<SpatialIndex> treeIndexes)
throws Exception {
- if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
+ // streamShapes is null when the query side is broadcast (see broadcastQueryObjects).
+ if (!treeIndexes.hasNext() || (streamShapes != null && !streamShapes.hasNext())) {
buildCount.add(0);
streamCount.add(0);
resultCount.add(0);
@@ -99,10 +108,9 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
}
STRtree strTree;
- if (broadcastedTreeIndex != null) {
- // get the broadcasted spatial index if available
- // this is to support the broadcast join
- strTree = broadcastedTreeIndex.getValue();
+ if (broadcastObjectsTreeIndex != null) {
+ // get the broadcast spatial index on objects side if available
+ strTree = broadcastObjectsTreeIndex.getValue();
} else {
// get the spatial index from the iterator
SpatialIndex treeIndex = treeIndexes.next();
@@ -113,44 +121,133 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
strTree = (STRtree) treeIndex;
}
+ // TODO: For future improvement, instead of using a list to store the results,
+ // we can use lazy evaluation to avoid storing all the results in memory.
List<Pair<U, T>> result = new ArrayList<>();
- ItemDistance itemDistance;
- while (streamShapes.hasNext()) {
- T streamShape = streamShapes.next();
- streamCount.add(1);
-
- Object[] localK;
- switch (distanceMetric) {
- case EUCLIDEAN:
- itemDistance = new EuclideanItemDistance();
- break;
- case HAVERSINE:
- itemDistance = new HaversineItemDistance();
- break;
- case SPHEROID:
- itemDistance = new SpheroidDistance();
- break;
- default:
- itemDistance = new GeometryItemDistance();
- break;
- }
- localK =
- strTree.nearestNeighbour(streamShape.getEnvelopeInternal(),
streamShape, itemDistance, k);
- if (includeTies) {
- localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
- }
- for (Object obj : localK) {
- T candidate = (T) obj;
- Pair<U, T> pair = Pair.of((U) streamShape, candidate);
- result.add(pair);
- resultCount.add(1);
- }
- }
- return result.iterator();
+ // Resolve the distance function once per partition instead of once per query geometry.
+ ItemDistance itemDistance = getItemDistance();
+ if (broadcastQueryObjects != null) {
+ // The query side is broadcast: evaluate every broadcast query geometry against this
+ // partition's object index. streamCount therefore counts each query once per partition.
+ for (Object item : broadcastQueryObjects.getValue()) {
+ T queryGeom =
+ item instanceof UniqueGeometry
+ ? (T) ((UniqueGeometry) item).getOriginalGeometry()
+ : (T) item;
+ // Keep the (possibly wrapped) item as the result key so duplicate query geometries
+ // stay distinct when the per-partition results are grouped by key downstream.
+ collectNeighbors((U) item, queryGeom, strTree, itemDistance, result);
+ }
+ } else {
+ while (streamShapes.hasNext()) {
+ T streamShape = streamShapes.next();
+ collectNeighbors((U) streamShape, streamShape, strTree, itemDistance, result);
+ }
+ }
+ return result.iterator();
+ }
+
+ /**
+ * Finds the nearest neighbors of one query geometry in {@code strTree}, applies the tie and
+ * search-radius rules, and appends one {@code (key, neighbor)} pair per match to {@code result}.
+ *
+ * @param key the value stored as the left side of every emitted pair
+ * @param queryGeom the geometry whose neighbors are searched
+ * @param strTree the index over the candidate geometries
+ * @param itemDistance the distance function used by the nearest-neighbor search
+ * @param result the list the matching pairs are appended to
+ */
+ private void collectNeighbors(
+ U key, T queryGeom, STRtree strTree, ItemDistance itemDistance, List<Pair<U, T>> result) {
+ streamCount.add(1);
+ Object[] localK =
+ strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom, itemDistance, k);
+ if (includeTies) {
+ localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
+ }
+ if (searchRadius != null) {
+ localK = getInSearchRadius(localK, queryGeom);
+ }
+ for (Object obj : localK) {
+ result.add(Pair.of(key, (T) obj));
+ resultCount.add(1);
+ }
+ }
+
+ /**
+ * Keeps only the candidates whose distance to {@code queryGeom}, computed with {@link
+ * #distanceByMetric}, is at most {@code searchRadius}.
+ */
+ private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
+ return Arrays.stream(localK)
+ .filter(
+ candidate ->
+ distanceByMetric(queryGeom, (Geometry) candidate, distanceMetric) <= searchRadius)
+ .toArray();
+ }
+
+ /**
+ * This method calculates the distance between two geometries using the specified distance metric.
+ *
+ * @param queryGeom the query geometry
+ * @param candidateGeom the candidate geometry
+ * @param distanceMetric the distance metric to use
+ * @return the distance between the two geometries
+ */
+ public static double distanceByMetric(
+ Geometry queryGeom, Geometry candidateGeom, DistanceMetric distanceMetric) {
+ switch (distanceMetric) {
+ case EUCLIDEAN:
+ EuclideanItemDistance euclideanItemDistance = new EuclideanItemDistance();
+ return euclideanItemDistance.distance(queryGeom, candidateGeom);
+ case HAVERSINE:
+ HaversineItemDistance haversineItemDistance = new HaversineItemDistance();
+ return haversineItemDistance.distance(queryGeom, candidateGeom);
+ case SPHEROID:
+ SpheroidDistance spheroidDistance = new SpheroidDistance();
+ return spheroidDistance.distance(queryGeom, candidateGeom);
+ default:
+ return queryGeom.distance(candidateGeom);
+ }
+ }
+
+ /** Returns the ItemDistance matching this judgement's distance metric. */
+ private ItemDistance getItemDistance() {
+ return getItemDistanceByMetric(distanceMetric);
+ }
+
+ /**
+ * This method returns the ItemDistance object based on the specified distance metric.
+ *
+ * @param distanceMetric the distance metric to use
+ * @return the ItemDistance object
+ */
+ public static ItemDistance getItemDistanceByMetric(DistanceMetric distanceMetric) {
+ ItemDistance itemDistance;
+ switch (distanceMetric) {
+ case EUCLIDEAN:
+ itemDistance = new EuclideanItemDistance();
+ break;
+ case HAVERSINE:
+ itemDistance = new HaversineItemDistance();
+ break;
+ case SPHEROID:
+ itemDistance = new SpheroidDistance();
+ break;
+ default:
+ itemDistance = new GeometryItemDistance();
+ break;
+ }
+ return itemDistance;
}
private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK,
STRtree strTree) {
@@ -184,4 +281,18 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
}
return localK;
}
+
+ /**
+ * Computes the distance between {@code key} and {@code value} for the given metric.
+ *
+ * <p>NOTE(review): this largely duplicates {@link #distanceByMetric}. The only difference is
+ * the fallback branch: here it uses EuclideanItemDistance, which returns Double.MAX_VALUE when
+ * both arguments are the same object, whereas distanceByMetric calls Geometry.distance.
+ * Consider consolidating the two.
+ *
+ * @param key the query geometry
+ * @param value the candidate geometry
+ * @param distanceMetric the distance metric to use
+ * @return the distance between the two geometries
+ */
+ public static <U extends Geometry, T extends Geometry> double distance(
+ U key, T value, DistanceMetric distanceMetric) {
+ switch (distanceMetric) {
+ case EUCLIDEAN:
+ return new EuclideanItemDistance().distance(key, value);
+ case HAVERSINE:
+ return new HaversineItemDistance().distance(key, value);
+ case SPHEROID:
+ return new SpheroidDistance().distance(key, value);
+ default:
+ return new EuclideanItemDistance().distance(key, value);
+ }
+ }
}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
index a27bf543b1..1aba8f87f7 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
@@ -36,4 +36,12 @@ public class EuclideanItemDistance implements ItemDistance {
return g1.distance(g2);
}
}
+
+ /**
+ * Returns {@code geometry1.distance(geometry2)}, except that two references to the same object
+ * yield {@code Double.MAX_VALUE} instead of 0.
+ *
+ * <p>NOTE(review): the identity sentinel presumably mirrors the ItemBoundable overload above, so
+ * that an item is not treated as its own neighbor. Callers using this method for plain distance
+ * filtering (e.g. a search radius) should be aware of it.
+ */
+ public double distance(Geometry geometry1, Geometry geometry2) {
+ if (geometry1 == geometry2) {
+ return Double.MAX_VALUE;
+ } else {
+ return geometry1.distance(geometry2);
+ }
+ }
}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
index 9ad1bfbee4..b04627074e 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
@@ -37,4 +37,12 @@ public class HaversineItemDistance implements ItemDistance {
return Haversine.distance(g1, g2);
}
}
+
+ /**
+ * Returns {@code Haversine.distance(geometry1, geometry2)}, except that two references to the
+ * same object yield {@code Double.MAX_VALUE} instead of 0.
+ *
+ * <p>NOTE(review): the identity sentinel presumably mirrors the ItemBoundable overload above, so
+ * that an item is not treated as its own neighbor. Callers using this method for plain distance
+ * filtering (e.g. a search radius) should be aware of it.
+ */
+ public double distance(Geometry geometry1, Geometry geometry2) {
+ if (geometry1 == geometry2) {
+ return Double.MAX_VALUE;
+ } else {
+ return Haversine.distance(geometry1, geometry2);
+ }
+ }
}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
index df22d3565e..4ecdbf84c6 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
@@ -37,4 +37,12 @@ public class SpheroidDistance implements ItemDistance {
return Spheroid.distance(g1, g2);
}
}
+
+ /**
+ * Returns {@code Spheroid.distance(geometry1, geometry2)}, except that two references to the
+ * same object yield {@code Double.MAX_VALUE} instead of 0.
+ *
+ * <p>NOTE(review): the identity sentinel presumably mirrors the ItemBoundable overload above, so
+ * that an item is not treated as its own neighbor. Callers using this method for plain distance
+ * filtering (e.g. a search radius) should be aware of it.
+ */
+ public double distance(Geometry geometry1, Geometry geometry2) {
+ if (geometry1 == geometry2) {
+ return Double.MAX_VALUE;
+ } else {
+ return Spheroid.distance(geometry1, geometry2);
+ }
+ }
}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
index d20563d279..a5665726e0 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
@@ -18,10 +18,7 @@
*/
package org.apache.sedona.core.spatialOperator;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Objects;
+import java.util.*;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
@@ -35,15 +32,18 @@ import org.apache.sedona.core.monitoring.Metrics;
import org.apache.sedona.core.spatialPartitioning.SpatialPartitioner;
import org.apache.sedona.core.spatialRDD.CircleRDD;
import org.apache.sedona.core.spatialRDD.SpatialRDD;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.STRtree;
import scala.Tuple2;
@@ -784,47 +784,82 @@ public class JoinQuery {
LongAccumulator resultCount = Metrics.createMetric(sparkContext,
"resultCount");
LongAccumulator candidateCount = Metrics.createMetric(sparkContext,
"candidateCount");
- final Broadcast<STRtree> broadcastedTreeIndex;
- if (broadcastJoin) {
- // adjust auto broadcast threshold to avoid building index on large RDDs
+ final Broadcast<STRtree> broadcastObjectsTreeIndex;
+ final Broadcast<List> broadcastQueryObjects;
+ if (broadcastJoin && objectRDD.indexedRawRDD != null && objectRDD.indexedRDD == null) {
+ // A broadcast join is requested and only a raw (not spatially partitioned) index exists on
+ // the object side: broadcast the query geometries to the object partitions.
+ List<UniqueGeometry<U>> uniqueQueryObjects = new ArrayList<>();
+ for (U queryObject : queryRDD.rawSpatialRDD.collect()) {
+ // Wrap each query geometry in a UniqueGeometry to account for duplicate query geometries
+ // when the per-partition results are grouped by query.
+ uniqueQueryObjects.add(new UniqueGeometry<>(queryObject));
+ }
+ broadcastQueryObjects =
+ JavaSparkContext.fromSparkContext(sparkContext).broadcast(uniqueQueryObjects);
+ broadcastObjectsTreeIndex = null;
+ } else if (broadcastJoin && objectRDD.indexedRawRDD == null && objectRDD.indexedRDD == null) {
+ // A broadcast join is requested and no index exists on the object side:
+ // build one and broadcast it to the query partitions.
STRtree strTree = objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE);
- broadcastedTreeIndex =
JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree);
+ broadcastObjectsTreeIndex =
+ JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree);
+ broadcastQueryObjects = null;
} else {
- broadcastedTreeIndex = null;
+ // A regular join does not need to broadcast anything.
+ broadcastQueryObjects = null;
+ broadcastObjectsTreeIndex = null;
}
// The reason for using objectRDD as the right side is that the partitions
are built on the
// right side.
final JavaRDD<Pair<U, T>> joinResult;
- if (objectRDD.indexedRDD != null) {
+ if (broadcastObjectsTreeIndex == null
+ && broadcastQueryObjects == null
+ && objectRDD.indexedRDD != null) {
+ // No broadcast join. indexedRDD is checked explicitly so that a missing index reaches the
+ // IllegalArgumentException below instead of passing a null RDD to zipPartitions.
final KnnJoinIndexJudgement judgement =
new KnnJoinIndexJudgement(
joinParams.k,
+ joinParams.searchRadius,
joinParams.distanceMetric,
includeTies,
- broadcastedTreeIndex,
+ null,
+ null,
buildCount,
streamCount,
resultCount,
candidateCount);
joinResult =
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement);
- } else if (broadcastedTreeIndex != null) {
+ } else if (broadcastObjectsTreeIndex != null) {
+ // broadcast join with objectRDD as broadcast side
final KnnJoinIndexJudgement judgement =
new KnnJoinIndexJudgement(
joinParams.k,
+ joinParams.searchRadius,
joinParams.distanceMetric,
includeTies,
- broadcastedTreeIndex,
+ null,
+ broadcastObjectsTreeIndex,
buildCount,
streamCount,
resultCount,
candidateCount);
- int numPartitionsObjects = objectRDD.rawSpatialRDD.getNumPartitions();
- joinResult =
- queryRDD
- .rawSpatialRDD
- .repartition(numPartitionsObjects)
- .zipPartitions(objectRDD.rawSpatialRDD, judgement);
+ // The object shapes are not needed here because the index comes from the broadcast. The
+ // query RDD is zipped with itself only to supply a non-empty second iterator, which
+ // satisfies the judgement's treeIndexes.hasNext() check.
+ // NOTE(review): unless queryRDD.rawSpatialRDD is cached, this likely computes each query
+ // partition twice.
+ joinResult =
+ queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement);
+ } else if (broadcastQueryObjects != null) {
+ // broadcast join with queryRDD as broadcast side
+ final KnnJoinIndexJudgement judgement =
+ new KnnJoinIndexJudgement(
+ joinParams.k,
+ joinParams.searchRadius,
+ joinParams.distanceMetric,
+ includeTies,
+ broadcastQueryObjects,
+ null,
+ buildCount,
+ streamCount,
+ resultCount,
+ candidateCount);
+ joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, includeTies);
} else {
throw new IllegalArgumentException("No index found on the input RDDs.");
}
@@ -833,6 +868,123 @@ public class JoinQuery {
(PairFunction<Pair<U, T>, U, T>) pair -> new Tuple2<>(pair.getKey(),
pair.getValue()));
}
+ /**
+ * Performs a KNN join where the query side is broadcast.
+ *
+ * <p>The query geometries are broadcast to every partition of the object geometries; the
+ * broadcast is carried inside {@code judgement}. Each partition of {@code
+ * objectRDD.indexedRawRDD} finds the nearest neighbors of every query in its local index. The
+ * per-partition candidates are then grouped by query and reduced, by distance, to the k nearest
+ * pairs per query.
+ *
+ * @param objectRDD The set of geometries (neighbors) to be queried; its raw index must be built.
+ * @param joinParams The parameters for the join, including the number of neighbors (k) and the
+ *     distance metric.
+ * @param judgement The judgement function used to perform the per-partition KNN search. The keys
+ *     of its results are expected to be {@link UniqueGeometry} wrappers around the queries.
+ * @param includeTies whether to also keep candidates whose distance equals the k-th smallest one.
+ * @param <U> The type of the geometries in the queryRDD set.
+ * @param <T> The type of the geometries in the objectRDD set.
+ * @return A JavaRDD of pairs where each pair contains a geometry from the queryRDD and a matching
+ *     geometry from the objectRDD.
+ */
+ private static <U extends Geometry, T extends Geometry>
+ JavaRDD<Pair<U, T>> querySideBroadcastKNNJoin(
+ SpatialRDD<T> objectRDD,
+ JoinParams joinParams,
+ KnnJoinIndexJudgement judgement,
+ boolean includeTies) {
+ JavaRDD<Pair<U, T>> joinResultMapped =
+ objectRDD.indexedRawRDD.mapPartitions(
+ iterator -> {
+ List<Pair<U, T>> results = new ArrayList<>();
+ if (iterator.hasNext()) {
+ SpatialIndex spatialIndex = iterator.next();
+ // the broadcast join won't need inputs from the query's shape stream
+ Iterator<Pair<U, T>> callResult =
+ judgement.call(null, Collections.singletonList(spatialIndex).iterator());
+ callResult.forEachRemaining(results::add);
+ }
+ return results.iterator();
+ });
+ // Copy the parameters used inside the Spark closure into locals so the closure captures only
+ // these values rather than the joinParams object.
+ int k = joinParams.k;
+ DistanceMetric distanceMetric = joinParams.distanceMetric;
+
+ // Each object partition produced candidates for every query; merge the candidates of each
+ // query across partitions and keep its global top k.
+ return joinResultMapped
+ .groupBy(pair -> pair.getKey()) // keys are UniqueGeometry wrappers, one per query
+ .flatMap(
+ (FlatMapFunction<Tuple2<U, Iterable<Pair<U, T>>>, Pair<U, T>>)
+ group -> selectTopK(group._2, k, distanceMetric, includeTies));
+ }
+
+ /**
+ * Selects the k candidates of one query that are nearest to it, optionally including ties.
+ *
+ * <p>The keys of {@code candidates} are unwrapped from {@link UniqueGeometry}, so the returned
+ * pairs carry the original query geometry. Each distance is computed exactly once. The sort is
+ * stable, so equally distant candidates keep their incoming order.
+ *
+ * @param candidates the (wrapped query, neighbor) pairs gathered for one query
+ * @param k the number of neighbors to keep
+ * @param distanceMetric the metric used to rank the neighbors
+ * @param includeTies whether to also keep candidates at exactly the k-th smallest distance
+ * @return the selected (query, neighbor) pairs in ascending distance order
+ */
+ private static <U extends Geometry, T extends Geometry> Iterator<Pair<U, T>> selectTopK(
+ Iterable<Pair<U, T>> candidates, int k, DistanceMetric distanceMetric, boolean includeTies) {
+ List<Pair<Pair<U, T>, Double>> scored = new ArrayList<>();
+ for (Pair<U, T> p : candidates) {
+ U query = (U) ((UniqueGeometry<?>) p.getKey()).getOriginalGeometry();
+ T neighbor = p.getValue();
+ double d = KnnJoinIndexJudgement.distance(query, neighbor, distanceMetric);
+ scored.add(Pair.of(Pair.of(query, neighbor), d));
+ }
+ // Ascending by distance.
+ scored.sort((a, b) -> Double.compare(a.getValue(), b.getValue()));
+
+ List<Pair<U, T>> topPairs = new ArrayList<>();
+ int limit = Math.min(k, scored.size());
+ for (int i = 0; i < limit; i++) {
+ topPairs.add(scored.get(i).getKey());
+ }
+ // Ties only exist when there is a k-th element; the k > 0 guard avoids indexing at -1.
+ if (includeTies && k > 0 && limit == k) {
+ double kthDistance = scored.get(k - 1).getValue();
+ for (int i = k; i < scored.size(); i++) {
+ double currentDistance = scored.get(i).getValue();
+ if (currentDistance != kthDistance) {
+ break;
+ }
+ topPairs.add(scored.get(i).getKey());
+ }
+ }
+ return topPairs.iterator();
+ }
+
public static final class JoinParams {
public final boolean useIndex;
public final SpatialPredicate spatialPredicate;
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java
b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java
new file mode 100644
index 0000000000..01f20f2fa6
--- /dev/null
+++
b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.core.wrapper;
+
+import java.util.UUID;
+import org.apache.commons.lang3.NotImplementedException;
+import org.locationtech.jts.geom.*;
+
+/**
+ * Geometry wrapper that gives each wrapped value its own identity.
+ *
+ * <p>Equality and hashing are based solely on a random UUID assigned at construction, so two
+ * wrappers around equal (or even identical) geometries are never equal. The query-side broadcast
+ * KNN join wraps each query geometry this way so that duplicate queries stay separate when results
+ * are grouped by query. Apart from equals, hashCode and toString, the Geometry operations are not
+ * supported and throw {@link NotImplementedException}.
+ *
+ * @param <T> the type of the wrapped value
+ */
+public class UniqueGeometry<T> extends Geometry {
+ // All wrappers share one default factory instead of allocating a new one per instance. The
+ // wrapper never builds geometries, and a shared instance can keep serialized broadcasts smaller.
+ private static final GeometryFactory SHARED_FACTORY = new GeometryFactory();
+
+ private final T originalGeometry;
+ private final String uniqueId;
+
+ /** Wraps {@code originalGeometry} and assigns it a fresh random identifier. */
+ public UniqueGeometry(T originalGeometry) {
+ super(SHARED_FACTORY);
+ this.originalGeometry = originalGeometry;
+ this.uniqueId = UUID.randomUUID().toString();
+ }
+
+ /** Returns the wrapped value. */
+ public T getOriginalGeometry() {
+ return originalGeometry;
+ }
+
+ /** Returns the random identifier that defines this wrapper's equality and hash code. */
+ public String getUniqueId() {
+ return uniqueId;
+ }
+
+ @Override
+ public int hashCode() {
+ return uniqueId.hashCode(); // Uniqueness ensured by uniqueId
+ }
+
+ @Override
+ public String getGeometryType() {
+ throw new NotImplementedException("getGeometryType is not implemented.");
+ }
+
+ @Override
+ public Coordinate getCoordinate() {
+ throw new NotImplementedException("getCoordinate is not implemented.");
+ }
+
+ @Override
+ public Coordinate[] getCoordinates() {
+ throw new NotImplementedException("getCoordinates is not implemented.");
+ }
+
+ @Override
+ public int getNumPoints() {
+ throw new NotImplementedException("getNumPoints is not implemented.");
+ }
+
+ @Override
+ public boolean isEmpty() {
+ throw new NotImplementedException("isEmpty is not implemented.");
+ }
+
+ @Override
+ public int getDimension() {
+ throw new NotImplementedException("getDimension is not implemented.");
+ }
+
+ @Override
+ public Geometry getBoundary() {
+ throw new NotImplementedException("getBoundary is not implemented.");
+ }
+
+ @Override
+ public int getBoundaryDimension() {
+ throw new NotImplementedException("getBoundaryDimension is not implemented.");
+ }
+
+ /** Two wrappers are equal only when they carry the same unique identifier. */
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) return true;
+ if (obj == null || getClass() != obj.getClass()) return false;
+ UniqueGeometry<?> that = (UniqueGeometry<?>) obj;
+ return uniqueId.equals(that.uniqueId);
+ }
+
+ @Override
+ public String toString() {
+ return "UniqueGeometry{"
+ + "originalGeometry="
+ + originalGeometry
+ + ", uniqueId='"
+ + uniqueId
+ + '\''
+ + '}';
+ }
+
+ @Override
+ protected Geometry reverseInternal() {
+ throw new NotImplementedException("reverseInternal is not implemented.");
+ }
+
+ @Override
+ public boolean equalsExact(Geometry geometry, double v) {
+ throw new NotImplementedException("equalsExact is not implemented.");
+ }
+
+ @Override
+ public void apply(CoordinateFilter coordinateFilter) {
+ throw new NotImplementedException("apply(CoordinateFilter) is not implemented.");
+ }
+
+ @Override
+ public void apply(CoordinateSequenceFilter coordinateSequenceFilter) {
+ throw new NotImplementedException("apply(CoordinateSequenceFilter) is not implemented.");
+ }
+
+ @Override
+ public void apply(GeometryFilter geometryFilter) {
+ throw new NotImplementedException("apply(GeometryFilter) is not implemented.");
+ }
+
+ @Override
+ public void apply(GeometryComponentFilter geometryComponentFilter) {
+ throw new NotImplementedException("apply(GeometryComponentFilter) is not implemented.");
+ }
+
+ @Override
+ protected Geometry copyInternal() {
+ throw new NotImplementedException("copyInternal is not implemented.");
+ }
+
+ @Override
+ public void normalize() {
+ throw new NotImplementedException("normalize is not implemented.");
+ }
+
+ @Override
+ protected Envelope computeEnvelopeInternal() {
+ throw new NotImplementedException("computeEnvelopeInternal is not implemented.");
+ }
+
+ @Override
+ protected int compareToSameClass(Object o) {
+ throw new NotImplementedException("compareToSameClass(Object) is not implemented.");
+ }
+
+ @Override
+ protected int compareToSameClass(
+ Object o, CoordinateSequenceComparator coordinateSequenceComparator) {
+ throw new NotImplementedException(
+ "compareToSameClass(Object, CoordinateSequenceComparator) is not implemented.");
+ }
+
+ @Override
+ protected int getTypeCode() {
+ throw new NotImplementedException("getTypeCode is not implemented.");
+ }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 001c0a1ca3..9ce40c6d42 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -130,19 +130,10 @@ case class BroadcastQuerySideKNNJoinExec(
require(kValue >= 1, "The number of neighbors (k) must be equal or greater
than 1.")
objectsShapes.setNeighborSampleNumber(kValue)
- val joinPartitions: Integer = numPartitions
- broadcastJoin = false
-
- // expand the boundary for partition to include both RDDs
- objectsShapes.analyze()
- queryShapes.analyze()
-
objectsShapes.boundaryEnvelope.expandToInclude(queryShapes.boundaryEnvelope)
-
- objectsShapes.spatialPartitioning(GridType.QUADTREE_RTREE, joinPartitions)
- queryShapes.spatialPartitioning(
-
objectsShapes.getPartitioner.asInstanceOf[QuadTreeRTPartitioner].nonOverlappedPartitioner())
-
- objectsShapes.buildIndex(IndexType.RTREE, true)
+ // Index the objects on their existing (non-spatial) partitions, which avoids the cost of
+ // spatial partitioning. With broadcastJoin set and only the raw index built, JoinQuery.knnJoin
+ // broadcasts the query side to the object partitions (presumably buildIndex(..., false)
+ // populates indexedRawRDD rather than indexedRDD -- confirm).
+ objectsShapes.buildIndex(IndexType.RTREE, false)
+ broadcastJoin = true
}
/**
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index da9bd5359b..b89b1adeda 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -589,7 +589,14 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
val leftShape = children.head
val rightShape = children.tail.head
- val querySide = getKNNQuerySide(left, leftShape)
+ // Decide which join side holds the query geometries from how the shape expressions map onto
+ // the left/right plans.
+ val querySide = matchExpressionsToPlans(leftShape, rightShape, left,
right) match {
+ case Some((_, _, false)) =>
+ LeftSide
+ case Some((_, _, true)) =>
+ RightSide
+ // NOTE(review): Nil is not a join side, so querySide widens to a non-JoinSide type, and the
+ // comparison below silently treats this case like RightSide (objectSidePlan = left).
+ // Consider failing fast or choosing an explicit default. This match is duplicated in the
+ // other KNN join branch and could be extracted into a helper.
+ case None =>
+ Nil
+ }
val objectSidePlan = if (querySide == LeftSide) right else left
checkObjectPlanFilterPushdown(objectSidePlan)
@@ -722,7 +729,14 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
val leftShape = children.head
val rightShape = children.tail.head
- val querySide = getKNNQuerySide(left, leftShape)
+ // Decide which join side holds the query geometries from how the shape expressions map onto
+ // the left/right plans.
+ val querySide = matchExpressionsToPlans(leftShape, rightShape, left,
right) match {
+ case Some((_, _, false)) =>
+ LeftSide
+ case Some((_, _, true)) =>
+ RightSide
+ // NOTE(review): Nil is not a join side, so querySide widens to a non-JoinSide type, and the
+ // comparison below silently treats this case like RightSide (objectSidePlan = left).
+ // Consider failing fast or choosing an explicit default. This match is duplicated in the
+ // other KNN join branch and could be extracted into a helper.
+ case None =>
+ Nil
+ }
val objectSidePlan = if (querySide == LeftSide) right else left
checkObjectPlanFilterPushdown(objectSidePlan)
@@ -739,7 +753,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
k = distance.get,
useApproximate = false,
spatialPredicate,
- isGeography = false,
+ isGeography,
condition = null,
extraCondition = None) :: Nil
} else {
@@ -754,7 +768,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
k = distance.get,
useApproximate = false,
spatialPredicate,
- isGeography = false,
+ isGeography,
condition = null,
extraCondition = None) :: Nil
}
@@ -865,27 +879,6 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
}
}
- /**
- * Gets the query and object plans based on the left shape.
- *
- * This method checks if the left shape is part of the left or right plan
and returns the query
- * and object plans accordingly.
- *
- * @param leftShape
- * The left shape expression.
- * @return
- * The join side where the left shape is located.
- */
- private def getKNNQuerySide(left: LogicalPlan, leftShape: Expression) = {
- val isLeftQuerySide =
-
left.toString().toLowerCase().contains(leftShape.toString().toLowerCase())
- if (isLeftQuerySide) {
- LeftSide
- } else {
- RightSide
- }
- }
-
/**
* Check if the given condition is an equi-join between the given plans.
This method basically
* replicates the logic of