This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 1fd3b86518 [SEDONA-690] Set default metric to use Haversine for KNN
join and code refactoring (#1909)
1fd3b86518 is described below
commit 1fd3b86518d97c56a72f4f31a4dcfb67a6d55496
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Apr 10 09:52:58 2025 -0700
[SEDONA-690] Set default metric to use Haversine for KNN join and code
refactoring (#1909)
* [SEDONA-690] Set default metric to use Haversine for KNN join and some
code refactor
* fix unit tests
* clean up join params
---
.../joinJudgement/InMemoryKNNJoinIterator.java | 155 ++++++++++++++++
.../core/joinJudgement/KnnJoinIndexJudgement.java | 200 +++++++--------------
.../sedona/core/spatialOperator/JoinQuery.java | 169 ++++++++---------
.../join/BroadcastObjectSideKNNJoinExec.scala | 4 +-
.../join/BroadcastQuerySideKNNJoinExec.scala | 4 +-
.../sql/sedona_sql/strategy/join/KNNJoinExec.scala | 4 +-
.../scala/org/apache/sedona/sql/KnnJoinSuite.scala | 28 +++
7 files changed, 329 insertions(+), 235 deletions(-)
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
new file mode 100644
index 0000000000..54ba42485e
--- /dev/null
+++
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.core.joinJudgement;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.NoSuchElementException;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sedona.core.enums.DistanceMetric;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
+import org.apache.spark.util.LongAccumulator;
+import org.locationtech.jts.geom.Envelope;
+import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.index.strtree.ItemDistance;
+import org.locationtech.jts.index.strtree.STRtree;
+
+public class InMemoryKNNJoinIterator<T extends Geometry, U extends Geometry>
+ implements Iterator<Pair<T, U>> { // lazily yields (query item, neighbour) pairs, one query item per batch
+ private final Iterator<T> querySideIterator; // source of query-side items; one element is consumed per batch
+ private final STRtree strTree; // object-side index searched for every query item
+
+ private final int k; // number of neighbours requested from the index per query item
+ private final DistanceMetric distanceMetric; // assigned in the constructor; not read anywhere else in this class
+ private final boolean includeTies; // when true, each batch is extended by getUpdatedLocalKWithTies
+ private final ItemDistance itemDistance; // distance function handed to STRtree.nearestNeighbour
+
+ private final LongAccumulator streamCount; // incremented once per query item pulled
+ private final LongAccumulator resultCount; // incremented once per pair added to a batch
+
+ private final List<Pair<T, U>> currentResults = new ArrayList<>(); // pairs buffered for the current query item
+ private int currentResultIndex = 0; // position of the next pair to hand out from currentResults
+
+ public InMemoryKNNJoinIterator( // stores its arguments and resolves the ItemDistance; no index search happens here
+ Iterator<T> querySideIterator,
+ STRtree strTree,
+ int k,
+ DistanceMetric distanceMetric,
+ boolean includeTies,
+ LongAccumulator streamCount,
+ LongAccumulator resultCount) {
+ this.querySideIterator = querySideIterator;
+ this.strTree = strTree;
+
+ this.k = k;
+ this.distanceMetric = distanceMetric;
+ this.includeTies = includeTies;
+ this.itemDistance = KnnJoinIndexJudgement.getItemDistance(distanceMetric); // resolved once, reused for every search
+
+ this.streamCount = streamCount;
+ this.resultCount = resultCount;
+ }
+
+ @Override
+ public boolean hasNext() { // true while a buffered pair remains or a later query item yields at least one pair
+ if (currentResultIndex < currentResults.size()) {
+ return true; // still draining the current batch
+ }
+
+ currentResultIndex = 0; // batch exhausted: reset the buffer before refilling it
+ currentResults.clear();
+ while (querySideIterator.hasNext()) {
+ populateNextBatch();
+ if (!currentResults.isEmpty()) { // query items that produced no pairs are skipped
+ return true;
+ }
+ }
+
+ return false; // query side exhausted and nothing buffered
+ }
+
+ @Override
+ public Pair<T, U> next() { // relies on hasNext() to refill the buffer before reading from it
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+
+ return currentResults.get(currentResultIndex++);
+ }
+
+ private void populateNextBatch() { // pulls one query item and appends its neighbour pairs to currentResults
+ T queryItem = querySideIterator.next();
+ Geometry queryGeom; // geometry actually used for the index search
+ if (queryItem instanceof UniqueGeometry) { // search with the wrapped geometry when the item is a UniqueGeometry
+ queryGeom = (Geometry) ((UniqueGeometry<?>)
queryItem).getOriginalGeometry();
+ } else {
+ queryGeom = queryItem;
+ }
+ streamCount.add(1);
+
+ Object[] localK =
+ strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom,
itemDistance, k);
+ if (includeTies) {
+ localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
+ }
+
+ for (Object obj : localK) {
+ U candidate = (U) obj; // unchecked cast: assumes the tree only holds U instances — TODO confirm at call sites
+ Pair<T, U> pair = Pair.of(queryItem, candidate); // key is the original queryItem (wrapper kept), not queryGeom
+ currentResults.add(pair);
+ resultCount.add(1);
+ }
+ }
+
+ private Object[] getUpdatedLocalKWithTies( // appends indexed items tied with the farthest of the k results
+ Geometry streamShape, Object[] localK, STRtree strTree) {
+ Envelope searchEnvelope = streamShape.getEnvelopeInternal();
+ // get the maximum distance from the k nearest neighbors
+ double maxDistance = 0.0;
+ LinkedHashSet<U> uniqueCandidates = new LinkedHashSet<>(); // the k results already present, so they are not re-added
+ for (Object obj : localK) {
+ U candidate = (U) obj;
+ uniqueCandidates.add(candidate);
+ double distance = streamShape.distance(candidate); // NOTE(review): Geometry.distance here, while the k-NN search used itemDistance — confirm both agree for non-Euclidean metrics
+ if (distance > maxDistance) {
+ maxDistance = distance;
+ }
+ }
+ searchEnvelope.expandBy(maxDistance); // grow the search window by maxDistance before querying the tree
+ List<U> candidates = strTree.query(searchEnvelope);
+ if (!candidates.isEmpty()) {
+ // update localK with all candidates that are within the maxDistance
+ List<Object> tiedResults = new ArrayList<>();
+ // add all localK
+ Collections.addAll(tiedResults, localK);
+
+ for (U candidate : candidates) {
+ double distance = streamShape.distance(candidate);
+ if (distance == maxDistance && !uniqueCandidates.contains(candidate)) { // a tie is exact floating-point equality with maxDistance
+ tiedResults.add(candidate);
+ }
+ }
+ localK = tiedResults.toArray();
+ }
+ return localK; // returned unchanged when the envelope query found nothing
+ }
+}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
index f5375009ed..0dda586986 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
@@ -19,19 +19,18 @@
package org.apache.sedona.core.joinJudgement;
import java.io.Serializable;
-import java.util.*;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sedona.core.enums.DistanceMetric;
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
-import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
-import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.GeometryItemDistance;
import org.locationtech.jts.index.strtree.ItemDistance;
import org.locationtech.jts.index.strtree.STRtree;
@@ -45,19 +44,17 @@ import org.locationtech.jts.index.strtree.STRtree;
*/
public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
extends JudgementBase<T, U>
- implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U,
T>>, Serializable {
+ implements FlatMapFunction2<Iterator<T>, Iterator<U>, Pair<T, U>>,
Serializable {
private final int k;
- private final Double searchRadius;
private final DistanceMetric distanceMetric;
private final boolean includeTies;
- private final Broadcast<List> broadcastQueryObjects;
+ private final Broadcast<List<T>> broadcastQueryObjects;
private final Broadcast<STRtree> broadcastObjectsTreeIndex;
/**
* Constructor for the KnnJoinIndexJudgement class.
*
* @param k the number of nearest neighbors to find
- * @param searchRadius
* @param distanceMetric the distance metric to use
* @param broadcastQueryObjects the broadcast geometries on queries
* @param broadcastObjectsTreeIndex the broadcast spatial index on objects
@@ -68,10 +65,9 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
*/
public KnnJoinIndexJudgement(
int k,
- Double searchRadius,
DistanceMetric distanceMetric,
boolean includeTies,
- Broadcast<List> broadcastQueryObjects,
+ Broadcast<List<T>> broadcastQueryObjects,
Broadcast<STRtree> broadcastObjectsTreeIndex,
LongAccumulator buildCount,
LongAccumulator streamCount,
@@ -79,7 +75,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
LongAccumulator candidateCount) {
super(null, buildCount, streamCount, resultCount, candidateCount);
this.k = k;
- this.searchRadius = searchRadius;
this.distanceMetric = distanceMetric;
this.includeTies = includeTies;
this.broadcastQueryObjects = broadcastQueryObjects;
@@ -91,15 +86,15 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
* and uses the spatial index to find the k nearest neighbors for each
geometry. The method
* returns an iterator over the join results.
*
- * @param streamShapes iterator over the geometries in the stream side
- * @param treeIndexes iterator over the spatial indexes
+ * @param queryShapes iterator over the geometries in the query side
+ * @param objectShapes iterator over the geometries in the object side
* @return an iterator over the join results
* @throws Exception if the spatial index is not of type STRtree
*/
@Override
- public Iterator<Pair<U, T>> call(Iterator<T> streamShapes,
Iterator<SpatialIndex> treeIndexes)
+ public Iterator<Pair<T, U>> call(Iterator<T> queryShapes, Iterator<U>
objectShapes)
throws Exception {
- if (!treeIndexes.hasNext() || (streamShapes != null &&
!streamShapes.hasNext())) {
+ if (!objectShapes.hasNext() || (queryShapes != null &&
!queryShapes.hasNext())) {
buildCount.add(0);
streamCount.add(0);
resultCount.add(0);
@@ -107,91 +102,64 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
return Collections.emptyIterator();
}
- STRtree strTree;
- if (broadcastObjectsTreeIndex != null) {
- // get the broadcast spatial index on objects side if available
- strTree = broadcastObjectsTreeIndex.getValue();
- } else {
- // get the spatial index from the iterator
- SpatialIndex treeIndex = treeIndexes.next();
- if (!(treeIndex instanceof STRtree)) {
- throw new Exception(
- "[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN
search.");
- }
- strTree = (STRtree) treeIndex;
- }
-
- // TODO: For future improvement, instead of using a list to store the
results,
- // we can use lazy evaluation to avoid storing all the results in memory.
- List<Pair<U, T>> result = new ArrayList<>();
-
- List queryItems;
- if (broadcastQueryObjects != null) {
- // get the broadcast spatial index on queries side if available
- queryItems = broadcastQueryObjects.getValue();
- for (Object item : queryItems) {
- T queryGeom;
- if (item instanceof UniqueGeometry) {
- queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry();
- } else {
- queryGeom = (T) item;
- }
- streamCount.add(1);
-
- Object[] localK =
- strTree.nearestNeighbour(
- queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(),
k);
- if (includeTies) {
- localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
- }
- if (searchRadius != null) {
- localK = getInSearchRadius(localK, queryGeom);
- }
+ STRtree strTree = buildSTRtree(objectShapes);
+ return new InMemoryKNNJoinIterator<>(
+ queryShapes, strTree, k, distanceMetric, includeTies, streamCount,
resultCount);
+ }
- for (Object obj : localK) {
- T candidate = (T) obj;
- Pair<U, T> pair = Pair.of((U) item, candidate);
- result.add(pair);
- resultCount.add(1);
- }
- }
- return result.iterator();
- } else {
- while (streamShapes.hasNext()) {
- T streamShape = streamShapes.next();
- streamCount.add(1);
+ /**
+ * This method performs the KNN join operation using the broadcast spatial
index built using all
+ * geometries in the object side.
+ *
+ * @param queryShapes iterator over the geometries in the query side
+ * @return an iterator over the join results
+ */
+ public Iterator<Pair<T, U>> callUsingBroadcastObjectIndex(Iterator<T>
queryShapes) {
+ if (!queryShapes.hasNext()) {
+ buildCount.add(0);
+ streamCount.add(0);
+ resultCount.add(0);
+ candidateCount.add(0);
+ return Collections.emptyIterator();
+ }
- Object[] localK =
- strTree.nearestNeighbour(
- streamShape.getEnvelopeInternal(), streamShape,
getItemDistance(), k);
- if (includeTies) {
- localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
- }
- if (searchRadius != null) {
- localK = getInSearchRadius(localK, streamShape);
- }
+ // There's no need to use external spatial index, since the object side is
small enough to be
+ // broadcasted, the STRtree built from the broadcasted object should be
able to fit into memory.
+ STRtree strTree = broadcastObjectsTreeIndex.getValue();
+ return new InMemoryKNNJoinIterator<>(
+ queryShapes, strTree, k, distanceMetric, includeTies, streamCount,
resultCount);
+ }
- for (Object obj : localK) {
- T candidate = (T) obj;
- Pair<U, T> pair = Pair.of((U) streamShape, candidate);
- result.add(pair);
- resultCount.add(1);
- }
- }
- return result.iterator();
+ /**
+ * This method performs the KNN join operation using the broadcast query
geometries.
+ *
+ * @param objectShapes iterator over the geometries in the object side
+ * @return an iterator over the join results
+ */
+ public Iterator<Pair<T, U>> callUsingBroadcastQueryList(Iterator<U>
objectShapes) {
+ if (!objectShapes.hasNext()) {
+ buildCount.add(0);
+ streamCount.add(0);
+ resultCount.add(0);
+ candidateCount.add(0);
+ return Collections.emptyIterator();
}
+
+ List<T> queryItems = broadcastQueryObjects.getValue();
+ STRtree strTree = buildSTRtree(objectShapes);
+ return new InMemoryKNNJoinIterator<>(
+ queryItems.iterator(), strTree, k, distanceMetric, includeTies,
streamCount, resultCount);
}
- private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
- localK =
- Arrays.stream(localK)
- .filter(
- candidate -> {
- Geometry candidateGeom = (Geometry) candidate;
- return distanceByMetric(queryGeom, candidateGeom,
distanceMetric) <= searchRadius;
- })
- .toArray();
- return localK;
+ private STRtree buildSTRtree(Iterator<U> objectShapes) { // drains objectShapes into a fresh in-memory STRtree
+ STRtree strTree = new STRtree();
+ while (objectShapes.hasNext()) {
+ U spatialObject = objectShapes.next();
+ strTree.insert(spatialObject.getEnvelopeInternal(), spatialObject); // each item is inserted under its own envelope
+ buildCount.add(1); // one count per indexed object
+ }
+ strTree.build(); // build eagerly before handing the tree to the caller
+ return strTree;
}
/**
@@ -219,12 +187,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
}
}
- private ItemDistance getItemDistance() {
- ItemDistance itemDistance;
- itemDistance = getItemDistanceByMetric(distanceMetric);
- return itemDistance;
- }
-
/**
* This method returns the ItemDistance object based on the specified
distance metric.
*
@@ -250,38 +212,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
return itemDistance;
}
- private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK,
STRtree strTree) {
- Envelope searchEnvelope = streamShape.getEnvelopeInternal();
- // get the maximum distance from the k nearest neighbors
- double maxDistance = 0.0;
- LinkedHashSet<T> uniqueCandidates = new LinkedHashSet<>();
- for (Object obj : localK) {
- T candidate = (T) obj;
- uniqueCandidates.add(candidate);
- double distance = streamShape.distance(candidate);
- if (distance > maxDistance) {
- maxDistance = distance;
- }
- }
- searchEnvelope.expandBy(maxDistance);
- List<T> candidates = strTree.query(searchEnvelope);
- if (!candidates.isEmpty()) {
- // update localK with all candidates that are within the maxDistance
- List<Object> tiedResults = new ArrayList<>();
- // add all localK
- Collections.addAll(tiedResults, localK);
-
- for (T candidate : candidates) {
- double distance = streamShape.distance(candidate);
- if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
- tiedResults.add(candidate);
- }
- }
- localK = tiedResults.toArray();
- }
- return localK;
- }
-
public static <U extends Geometry, T extends Geometry> double distance(
U key, T value, DistanceMetric distanceMetric) {
switch (distanceMetric) {
@@ -295,4 +225,10 @@ public class KnnJoinIndexJudgement<T extends Geometry, U
extends Geometry>
return new EuclideanItemDistance().distance(key, value);
}
}
+
+ public static ItemDistance getItemDistance(DistanceMetric distanceMetric) { // static entry point that delegates to getItemDistanceByMetric
+ ItemDistance itemDistance;
+ itemDistance = getItemDistanceByMetric(distanceMetric);
+ return itemDistance;
+ }
}
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
index a5665726e0..7b55dd0763 100644
---
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
+++
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
@@ -37,13 +37,11 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.STRtree;
import scala.Tuple2;
@@ -401,7 +399,7 @@ public class JoinQuery {
DistanceMetric distanceMetric)
throws Exception {
final JoinParams joinParams =
- new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric,
null);
+ new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric);
final JavaPairRDD<U, T> joinResults = knnJoin(queryRDD, objectRDD,
joinParams, false, false);
return collectGeometriesByKey(joinResults);
@@ -785,7 +783,7 @@ public class JoinQuery {
LongAccumulator candidateCount = Metrics.createMetric(sparkContext,
"candidateCount");
final Broadcast<STRtree> broadcastObjectsTreeIndex;
- final Broadcast<List> broadcastQueryObjects;
+ final Broadcast<List<UniqueGeometry<U>>> broadcastQueryObjects;
if (broadcastJoin && objectRDD.indexedRawRDD != null &&
objectRDD.indexedRDD == null) {
// If broadcastJoin is true and rawIndex is created on object side
// we will broadcast queryRDD to objectRDD
@@ -816,10 +814,9 @@ public class JoinQuery {
final JavaRDD<Pair<U, T>> joinResult;
if (broadcastObjectsTreeIndex == null && broadcastQueryObjects == null) {
// no broadcast join
- final KnnJoinIndexJudgement judgement =
- new KnnJoinIndexJudgement(
+ final KnnJoinIndexJudgement<U, T> judgement =
+ new KnnJoinIndexJudgement<>(
joinParams.k,
- joinParams.searchRadius,
joinParams.distanceMetric,
includeTies,
null,
@@ -828,13 +825,13 @@ public class JoinQuery {
streamCount,
resultCount,
candidateCount);
- joinResult =
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement);
+ joinResult =
+
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.spatialPartitionedRDD,
judgement);
} else if (broadcastObjectsTreeIndex != null) {
// broadcast join with objectRDD as broadcast side
- final KnnJoinIndexJudgement judgement =
- new KnnJoinIndexJudgement(
+ final KnnJoinIndexJudgement<U, T> judgement =
+ new KnnJoinIndexJudgement<>(
joinParams.k,
- joinParams.searchRadius,
joinParams.distanceMetric,
includeTies,
null,
@@ -844,13 +841,12 @@ public class JoinQuery {
resultCount,
candidateCount);
// won't need inputs from the shapes in the objectRDD
- joinResult =
queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement);
- } else if (broadcastQueryObjects != null) {
+ joinResult =
queryRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastObjectIndex);
+ } else {
// broadcast join with queryRDD as broadcast side
- final KnnJoinIndexJudgement judgement =
- new KnnJoinIndexJudgement(
+ final KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement =
+ new KnnJoinIndexJudgement<>(
joinParams.k,
- joinParams.searchRadius,
joinParams.distanceMetric,
includeTies,
broadcastQueryObjects,
@@ -860,8 +856,6 @@ public class JoinQuery {
resultCount,
candidateCount);
joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement,
includeTies);
- } else {
- throw new IllegalArgumentException("No index found on the input RDDs.");
}
return joinResult.mapToPair(
@@ -891,22 +885,11 @@ public class JoinQuery {
JavaRDD<Pair<U, T>> querySideBroadcastKNNJoin(
SpatialRDD<T> objectRDD,
JoinParams joinParams,
- KnnJoinIndexJudgement judgement,
+ KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement,
boolean includeTies) {
final JavaRDD<Pair<U, T>> joinResult;
- JavaRDD<Pair<U, T>> joinResultMapped =
- objectRDD.indexedRawRDD.mapPartitions(
- iterator -> {
- List<Pair<U, T>> results = new ArrayList<>();
- if (iterator.hasNext()) {
- SpatialIndex spatialIndex = iterator.next();
- // the broadcast join won't need inputs from the query's shape
stream
- Iterator<Pair<U, T>> callResult =
- judgement.call(null,
Collections.singletonList(spatialIndex).iterator());
- callResult.forEachRemaining(results::add);
- }
- return results.iterator();
- });
+ JavaRDD<Pair<UniqueGeometry<U>, T>> joinResultMapped =
+
objectRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastQueryList);
// this is to avoid serializable issues with the broadcast variable
int k = joinParams.k;
DistanceMetric distanceMetric = joinParams.distanceMetric;
@@ -915,72 +898,67 @@ public class JoinQuery {
// (based on a grouping key and distance)
joinResult =
joinResultMapped
- .groupBy(pair -> pair.getKey()) // Group by the first geometry
+ .groupBy(pair -> pair.getKey().getUniqueId())
.flatMap(
- (FlatMapFunction<Tuple2<U, Iterable<Pair<U, T>>>, Pair<U, T>>)
- pair -> {
- Iterable<Pair<U, T>> values = pair._2;
-
- // Extract and sort values by distance
- List<Pair<U, T>> sortedPairs = new ArrayList<>();
- for (Pair<U, T> p : values) {
- Pair<U, T> newPair =
- Pair.of(
- (U) ((UniqueGeometry<?>)
p.getKey()).getOriginalGeometry(),
- p.getValue());
- sortedPairs.add(newPair);
- }
-
- // Sort pairs based on the distance function between the
two geometries
- sortedPairs.sort(
- (p1, p2) -> {
- double distance1 =
- KnnJoinIndexJudgement.distance(
- p1.getKey(), p1.getValue(),
distanceMetric);
- double distance2 =
- KnnJoinIndexJudgement.distance(
- p2.getKey(), p2.getValue(),
distanceMetric);
- return Double.compare(
- distance1, distance2); // Sort ascending by
distance
- });
-
- if (includeTies) {
- // Keep the top k pairs, including ties
- List<Pair<U, T>> topPairs = new ArrayList<>();
- double kthDistance = -1;
- for (int i = 0; i < sortedPairs.size(); i++) {
- if (i < k) {
- topPairs.add(sortedPairs.get(i));
- if (i == k - 1) {
- kthDistance =
- KnnJoinIndexJudgement.distance(
- sortedPairs.get(i).getKey(),
- sortedPairs.get(i).getValue(),
- distanceMetric);
- }
- } else {
- double currentDistance =
- KnnJoinIndexJudgement.distance(
- sortedPairs.get(i).getKey(),
- sortedPairs.get(i).getValue(),
- distanceMetric);
- if (currentDistance == kthDistance) {
- topPairs.add(sortedPairs.get(i));
- } else {
- break;
- }
- }
+ pair -> {
+ Iterable<Pair<UniqueGeometry<U>, T>> values = pair._2;
+
+ // Extract and sort values by distance
+ List<Pair<U, T>> sortedPairs = new ArrayList<>();
+ for (Pair<UniqueGeometry<U>, T> p : values) {
+ Pair<U, T> newPair =
Pair.of(p.getKey().getOriginalGeometry(), p.getValue());
+ sortedPairs.add(newPair);
+ }
+
+ // Sort pairs based on the distance function between the two
geometries
+ sortedPairs.sort(
+ (p1, p2) -> {
+ double distance1 =
+ KnnJoinIndexJudgement.distance(
+ p1.getKey(), p1.getValue(), distanceMetric);
+ double distance2 =
+ KnnJoinIndexJudgement.distance(
+ p2.getKey(), p2.getValue(), distanceMetric);
+ return Double.compare(distance1, distance2); // Sort
ascending by distance
+ });
+
+ if (includeTies) {
+ // Keep the top k pairs, including ties
+ List<Pair<U, T>> topPairs = new ArrayList<>();
+ double kthDistance = -1;
+ for (int i = 0; i < sortedPairs.size(); i++) {
+ if (i < k) {
+ topPairs.add(sortedPairs.get(i));
+ if (i == k - 1) {
+ kthDistance =
+ KnnJoinIndexJudgement.distance(
+ sortedPairs.get(i).getKey(),
+ sortedPairs.get(i).getValue(),
+ distanceMetric);
}
- return topPairs.iterator();
} else {
- // Keep the top k pairs without ties
- List<Pair<U, T>> topPairs = new ArrayList<>();
- for (int i = 0; i < Math.min(k, sortedPairs.size());
i++) {
+ double currentDistance =
+ KnnJoinIndexJudgement.distance(
+ sortedPairs.get(i).getKey(),
+ sortedPairs.get(i).getValue(),
+ distanceMetric);
+ if (currentDistance == kthDistance) {
topPairs.add(sortedPairs.get(i));
+ } else {
+ break;
}
- return topPairs.iterator();
}
- });
+ }
+ return topPairs.iterator();
+ } else {
+ // Keep the top k pairs without ties
+ List<Pair<U, T>> topPairs = new ArrayList<>();
+ for (int i = 0; i < Math.min(k, sortedPairs.size()); i++) {
+ topPairs.add(sortedPairs.get(i));
+ }
+ return topPairs.iterator();
+ }
+ });
return joinResult;
}
@@ -994,14 +972,13 @@ public class JoinQuery {
// KNN specific parameters
public final int k;
public final DistanceMetric distanceMetric;
- public final Double searchRadius;
public JoinParams(
boolean useIndex,
SpatialPredicate spatialPredicate,
IndexType polygonIndexType,
JoinBuildSide joinBuildSide) {
- this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1,
null, null);
+ this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1,
null);
}
public JoinParams(
@@ -1010,15 +987,13 @@ public class JoinQuery {
IndexType polygonIndexType,
JoinBuildSide joinBuildSide,
int k,
- DistanceMetric distanceMetric,
- Double searchRadius) {
+ DistanceMetric distanceMetric) {
this.useIndex = useIndex;
this.spatialPredicate = spatialPredicate;
this.indexType = polygonIndexType;
this.joinBuildSide = joinBuildSide;
this.k = k;
this.distanceMetric = distanceMetric;
- this.searchRadius = searchRadius;
}
public JoinParams(boolean useIndex, SpatialPredicate spatialPredicate) {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
index c5777be3c1..f4bdae40d5 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
@@ -138,9 +138,9 @@ case class BroadcastObjectSideKNNJoinExec(
// Number of neighbors to find
val kValue: Int = this.k.eval().asInstanceOf[Int]
// Metric to use in the join to calculate the distance, only Euclidean and
Spheroid are supported
- val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else
DistanceMetric.EUCLIDEAN
+ val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else
DistanceMetric.EUCLIDEAN
val joinParams =
- new JoinParams(true, null, IndexType.RTREE, null, kValue,
distanceMetric, null)
+ new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
joinParams
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 9ce40c6d42..575ded9125 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -149,9 +149,9 @@ case class BroadcastQuerySideKNNJoinExec(
// Number of neighbors to find
val kValue: Int = this.k.eval().asInstanceOf[Int]
// Metric to use in the join to calculate the distance, only Euclidean and
Spheroid are supported
- val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else
DistanceMetric.EUCLIDEAN
+ val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else
DistanceMetric.EUCLIDEAN
val joinParams =
- new JoinParams(true, null, IndexType.RTREE, null, kValue,
distanceMetric, null)
+ new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
joinParams
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
index fdc53d13ce..a879447d40 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
@@ -209,9 +209,9 @@ case class KNNJoinExec(
// Number of neighbors to find
val kValue: Int = this.k.eval().asInstanceOf[Int]
// Metric to use in the join to calculate the distance, only Euclidean and
Spheroid are supported
- val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else
DistanceMetric.EUCLIDEAN
+ val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else
DistanceMetric.EUCLIDEAN
val joinParams =
- new JoinParams(true, null, IndexType.RTREE, null, kValue,
distanceMetric, null)
+ new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric)
joinParams
}
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
index ab2c64898a..50696337f7 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
@@ -458,6 +458,34 @@ class KnnJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
df2.cache()
df1.join(df2, expr("ST_KNN(geom1, geom2, 1)")).count() should be(0)
}
+
+ it("KNN Join using spider data source") { // smoke test: only checks that the join returns at least one row
+ val dfRandomSquares = sparkSession.read // 10000 rows generated by the "spider" source with the "parcel" distribution
+ .format("spider")
+ .option("n", "10000")
+ .option("distribution", "parcel")
+ .option("dither", "0.5")
+ .option("splitRange", "0.5")
+ .load()
+
+ dfRandomSquares.createOrReplaceTempView("df_random_squares")
+
+ val dfRandomPoints = sparkSession.read // 1000 rows generated with the "uniform" distribution
+ .format("spider")
+ .option("n", "1000")
+ .option("distribution", "uniform")
+ .load()
+
+ dfRandomPoints.createOrReplaceTempView("df_random_points")
+
+ // Execute a KNN join query: attribute points to the nearest square
+ val knnJoined = sparkSession.sql("""SELECT sq.id, pt.id
+ |FROM df_random_squares sq
+ |JOIN df_random_points pt
+ |ON ST_KNN(sq.geometry, pt.geometry, 1, TRUE)""".stripMargin)
+
+ assert(knnJoined.count() > 0) // NOTE(review): the 4th ST_KNN argument (TRUE) presumably selects the geography distance path — confirm against the ST_KNN docs
+ }
}
private def withOptimizationMode(mode: String)(body: => Unit): Unit = {