This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new c54d144dcd [GH-2356] Implement barrier udf function (#2357)
c54d144dcd is described below
commit c54d144dcd4da7c7ebc967cf2d6a5cfe4773e09e
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Sep 18 23:42:06 2025 -0700
[GH-2356] Implement barrier udf function (#2357)
* Implement barrier udf function
* comment about Barrier's CodegenFallback
* fix format lint
* address copilot comments
---
docs/api/sql/NearestNeighbourSearching.md | 34 ++
python/sedona/spark/sql/st_functions.py | 19 +
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../sedona_sql/expressions/BarrierFunction.scala | 206 ++++++++++
.../sql/sedona_sql/expressions/st_functions.scala | 9 +
.../strategy/join/JoinQueryDetector.scala | 38 --
.../apache/sedona/sql/BarrierFunctionTest.scala | 453 +++++++++++++++++++++
7 files changed, 722 insertions(+), 38 deletions(-)
diff --git a/docs/api/sql/NearestNeighbourSearching.md
b/docs/api/sql/NearestNeighbourSearching.md
index f45dccc064..26712a07a1 100644
--- a/docs/api/sql/NearestNeighbourSearching.md
+++ b/docs/api/sql/NearestNeighbourSearching.md
@@ -68,6 +68,40 @@ CACHE TABLE knnResult;
SELECT * FROM knnResult WHERE condition;
```
+### Optimization Barrier
+
+Use the `barrier` function to prevent filter pushdown and control predicate
evaluation order in complex spatial joins. This function creates an
optimization barrier by evaluating boolean expressions at runtime.
+
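+The `barrier` function takes a boolean expression as a string, followed by
pairs of variable names and the values bound to those names when the expression
is evaluated. The expression may use the comparison operators `=`, `!=`, `<>`,
`<`, `<=`, `>`, `>=`, the logical operators `AND`, `OR`, `NOT`, and parentheses
for grouping: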
+
+```sql
+barrier(expression, var_name1, var_value1, var_name2, var_value2, ...)
+```
+
+The placement of filters relative to KNN joins changes the semantic meaning of
the query:
+
+- **Filter before KNN**: First filters the data, then finds K nearest
neighbors from the filtered subset. This answers "What are the K nearest
high-rated restaurants?"
+- **Filter after KNN**: First finds K nearest neighbors from all data, then
filters those results. This answers "Of the K nearest restaurants, which ones
are high-rated?"
+
+### Example
+
+For each hotel, find its 3 nearest restaurants first, then keep only the pairs
where the restaurant is high-rated and the hotel is a luxury hotel. The barrier
ensures the KNN join completes before filtering.
+
+```sql
+SELECT
+ h.name AS hotel,
+ r.name AS restaurant,
+ r.rating
+FROM hotels AS h
+INNER JOIN restaurants AS r
+ON ST_KNN(h.geometry, r.geometry, 3, false)
+WHERE barrier('rating > 4.0 AND stars >= 4',
+ 'rating', r.rating,
+ 'stars', h.stars)
+```
+
+With the barrier function, this query first finds the 3 nearest restaurants to
each hotel (regardless of rating), then filters to keep only those pairs where
the restaurant has rating > 4.0 and the hotel has stars >= 4. Without the
barrier, an optimizer might push the filters down, changing the query to first
filter for high-rated restaurants and luxury hotels, then find the 3 nearest
among those filtered sets.
+
### Handling SQL-Defined Tables in ST_KNN Joins:
When creating DataFrames from hard-coded SQL select statements in Sedona, and
later using them in `ST_KNN` joins, Sedona may attempt to optimize the query in
a way that bypasses the intended kNN join logic. Specifically, if you create
DataFrames with hard-coded SQL, such as:
diff --git a/python/sedona/spark/sql/st_functions.py
b/python/sedona/spark/sql/st_functions.py
index 08fc20c350..fc10f4b9da 100644
--- a/python/sedona/spark/sql/st_functions.py
+++ b/python/sedona/spark/sql/st_functions.py
@@ -295,6 +295,25 @@ def ST_Azimuth(point_a: ColumnOrName, point_b:
ColumnOrName) -> Column:
return _call_st_function("ST_Azimuth", (point_a, point_b))
+@validate_argument_types
+def barrier(expression: ColumnOrName, *args) -> Column:
+ """Prevent filter pushdown and control predicate evaluation order in
complex spatial joins.
+ This function creates an optimization barrier by evaluating boolean
expressions at runtime.
+
+ :param expression: Boolean expression string to evaluate
+ :type expression: ColumnOrName
+ :param args: Variable name and value pairs (var_name1, var_value1,
var_name2, var_value2, ...)
+ :return: Boolean result of the expression evaluation
+ :rtype: Column
+
+ Example:
+ df.where(barrier('rating > 4.0 AND stars >= 4',
+ 'rating', col('r.rating'),
+ 'stars', col('h.stars')))
+ """
+ return _call_st_function("barrier", (expression,) + args)
+
+
@validate_argument_types
def ST_BestSRID(geometry: ColumnOrName) -> Column:
"""Estimates the best SRID (EPSG code) of the geometry.
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 42f9859b12..41b9a2233c 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -249,6 +249,7 @@ object Catalog extends AbstractCatalog {
function[ST_Rotate](),
function[ST_RotateX](),
function[ST_RotateY](),
+ function[Barrier](),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/BarrierFunction.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/BarrierFunction.scala
new file mode 100644
index 0000000000..32bd4252fb
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/BarrierFunction.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{DataType, BooleanType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+import scala.util.parsing.combinator._
+
+/**
+ * Barrier function intended to prevent filter pushdown and control predicate evaluation order.
+ * Takes a boolean expression string followed by pairs of variable names and their values, and
+ * evaluates the expression for each input row with those values bound to the names.
+ *
+ * Usage: barrier(expression, var_name1, var_value1, var_name2, var_value2, ...)
+ *
+ * Example: barrier('rating > 4.0 AND stars >= 4', 'rating', r.rating, 'stars', h.stars)
+ *
+ * Mixes in CodegenFallback, so the result is always produced by `eval`. The original author's
+ * stated intent is that this keeps the Catalyst optimizer from pushing the filter through joins.
+ * NOTE(review): confirm which optimizer rules this actually blocks.
+ *
+ * @param inputExpressions
+ *   the expression string first, then alternating variable-name and variable-value expressions
+ */
+private[apache] case class Barrier(inputExpressions: Seq[Expression])
+    extends Expression
+    with CodegenFallback {
+
+  // eval always returns a Boolean or throws; it never returns null.
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = BooleanType
+
+  override def children: Seq[Expression] = inputExpressions
+
+  /**
+   * Evaluates the barrier expression for one row.
+   *
+   * @param input
+   *   the row against which every child expression is evaluated
+   * @return
+   *   the Boolean result of the parsed expression
+   * @throws IllegalArgumentException
+   *   if no arguments were given, the expression or a variable name is null or not a string,
+   *   the name/value arguments are not paired, or the expression fails to parse
+   */
+  override def eval(input: InternalRow): Any = {
+    // Fail with a descriptive message instead of NoSuchElementException from `head`.
+    if (inputExpressions.isEmpty) {
+      throw new IllegalArgumentException(
+        "Barrier function requires a boolean expression string as its first argument")
+    }
+
+    // Get the expression string
+    val exprString = inputExpressions.head.eval(input) match {
+      case s: UTF8String => s.toString
+      case null => throw new IllegalArgumentException("Barrier expression cannot be null")
+      case other =>
+        throw new IllegalArgumentException(
+          s"Barrier expression must be a string, got: ${other.getClass}")
+    }
+
+    // Everything after the expression must come in (name, value) pairs.
+    val bindings = inputExpressions.tail
+    if (bindings.length % 2 != 0) {
+      throw new IllegalArgumentException(
+        "Barrier function requires pairs of variable names and values")
+    }
+
+    // Build the variable map; when a name repeats, the last binding wins.
+    val varMap: Map[String, Any] = bindings
+      .grouped(2)
+      .map { case Seq(nameExpr, valueExpr) =>
+        val varName = nameExpr.eval(input) match {
+          case s: UTF8String => s.toString
+          case null => throw new IllegalArgumentException("Variable name cannot be null")
+          case other =>
+            throw new IllegalArgumentException(
+              s"Variable name must be a string, got: ${other.getClass}")
+        }
+        varName -> valueExpr.eval(input)
+      }
+      .toMap
+
+    // Evaluate the expression with variable substitution
+    evaluateBooleanExpression(exprString, varMap)
+  }
+
+  /**
+   * Evaluates a boolean expression string with variable substitution. Supports basic comparison
+   * operators and logical operators (AND, OR, NOT).
+   *
+   * A new parser is created per call because the variable values are passed to its constructor.
+   */
+  private def evaluateBooleanExpression(
+      expression: String,
+      variables: Map[String, Any]): Boolean = {
+    val parser = new BooleanExpressionParser(variables)
+    parser.parseExpression(expression) match {
+      case parser.Success(result, _) => result
+      case parser.NoSuccess(msg, _) =>
+        throw new IllegalArgumentException(s"Failed to parse barrier expression: $msg")
+    }
+  }
+
+  // NOTE(review): declared without `override`, presumably so the class also compiles against
+  // Spark versions whose Expression lacks this method — confirm before adding `override`.
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(inputExpressions = newChildren)
+  }
+}
+
+/**
+ * Parser and evaluator for boolean expressions in the barrier function. The expression is
+ * evaluated while it is parsed, using the supplied variable bindings.
+ *
+ * Supports comparison operators: =, !=, <>, <, <=, >, >=
+ *
+ * Supports logical operators: AND, OR, NOT (case-insensitive; NOT binds tighter than AND, which
+ * binds tighter than OR)
+ *
+ * Supports parentheses for grouping
+ *
+ * Null handling is not SQL three-valued logic: `null = null` is true, and a comparison between
+ * null and a non-null value is true only for `!=` / `<>`.
+ *
+ * @param variables
+ *   values bound to the identifiers that may appear in the expression
+ */
+private class BooleanExpressionParser(variables: Map[String, Any]) extends JavaTokenParsers {
+
+  // Pre-compiled keyword patterns. The trailing word boundary keeps a keyword from matching the
+  // prefix of an identifier (e.g. `notes`, `true_count`, `null_flag`, `order_id`).
+  private val truePattern = """(?i)true\b""".r
+  private val falsePattern = """(?i)false\b""".r
+  private val nullPattern = """(?i)null\b""".r
+  private val andPattern = """(?i)AND\b""".r
+  private val orPattern = """(?i)OR\b""".r
+  private val notPattern = """(?i)NOT\b""".r
+
+  /** Parses and evaluates `expr`; the whole input must be consumed for the parse to succeed. */
+  def parseExpression(expr: String): ParseResult[Boolean] = parseAll(boolExpr, expr)
+
+  def boolExpr: Parser[Boolean] = orExpr
+
+  def orExpr: Parser[Boolean] = andExpr ~ rep(orPattern ~> andExpr) ^^ { case left ~ rights =>
+    rights.foldLeft(left)(_ || _)
+  }
+
+  def andExpr: Parser[Boolean] = notExpr ~ rep(andPattern ~> notExpr) ^^ { case left ~ rights =>
+    rights.foldLeft(left)(_ && _)
+  }
+
+  def notExpr: Parser[Boolean] =
+    notPattern ~> notExpr ^^ (!_) |
+      primaryExpr
+
+  def primaryExpr: Parser[Boolean] =
+    "(" ~> boolExpr <~ ")" |
+      attempt(comparison) |
+      booleanValue
+
+  def comparison: Parser[Boolean] = value ~ compOp ~ value ^^ { case left ~ op ~ right =>
+    compareValues(left, op, right)
+  }
+
+  /** Runs `p` and turns any non-success into a plain Failure at the original position. */
+  def attempt[T](p: Parser[T]): Parser[T] = Parser { in =>
+    p(in) match {
+      case s @ Success(_, _) => s
+      case _ => Failure("", in)
+    }
+  }
+
+  /**
+   * A bare boolean: the literals true/false, or an identifier bound to a Boolean value.
+   *
+   * Throws IllegalArgumentException for an unbound identifier or one bound to a non-Boolean.
+   */
+  def booleanValue: Parser[Boolean] =
+    truePattern ^^ (_ => true) |
+      falsePattern ^^ (_ => false) |
+      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ { name =>
+        variables.get(name) match {
+          case Some(b: Boolean) => b
+          case Some(other) =>
+            throw new IllegalArgumentException(s"Expected boolean value for $name, got: $other")
+          case None =>
+            throw new IllegalArgumentException(s"Unknown variable: $name")
+        }
+      }
+
+  // Two-character operators are listed first so that ">=" is not read as ">" followed by "=".
+  def compOp: Parser[String] = ">=" | "<=" | "!=" | "<>" | "=" | ">" | "<"
+
+  /**
+   * An operand of a comparison: a numeric literal, a double-quoted string literal, true/false,
+   * null, or an identifier resolved through `variables`.
+   *
+   * `floatingPointNumber` also matches plain integers, so one alternative covers every numeric
+   * literal; `parseNumber` then decides between Long and Double.
+   */
+  def value: Parser[Any] =
+    floatingPointNumber ^^ parseNumber |
+      stringLiteral ^^ (s => s.substring(1, s.length - 1)) | // Remove quotes
+      truePattern ^^ (_ => true) |
+      falsePattern ^^ (_ => false) |
+      nullPattern ^^ (_ => null) |
+      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ (name =>
+        variables.getOrElse(name, throw new IllegalArgumentException(s"Unknown variable: $name")))
+
+  /**
+   * Converts a numeric token to a boxed Long when it is a plain integer that fits in a Long, and
+   * to a boxed Double otherwise.
+   */
+  private def parseNumber(token: String): Any =
+    if (token.matches("-?\\d+")) {
+      try java.lang.Long.valueOf(token)
+      catch {
+        // Integer literal outside the Long range: keep the previous Double behaviour.
+        case _: NumberFormatException => java.lang.Double.valueOf(token)
+      }
+    } else {
+      java.lang.Double.valueOf(token)
+    }
+
+  /** Applies `op` to the operands, handling nulls before delegating to `compareNonNull`. */
+  private def compareValues(left: Any, op: String, right: Any): Boolean = {
+    (left, right) match {
+      case (null, null) => op == "=" || op == "<=" || op == ">="
+      case (null, _) | (_, null) => op == "!=" || op == "<>"
+      case _ =>
+        val comparison = compareNonNull(left, right)
+        op match {
+          case "=" => comparison == 0
+          case "!=" | "<>" => comparison != 0
+          case "<" => comparison < 0
+          case "<=" => comparison <= 0
+          case ">" => comparison > 0
+          case ">=" => comparison >= 0
+        }
+    }
+  }
+
+  /** True for the boxed integral types, which can be compared exactly as Long. */
+  private def isIntegral(n: Number): Boolean = n match {
+    case _: java.lang.Long | _: java.lang.Integer | _: java.lang.Short | _: java.lang.Byte => true
+    case _ => false
+  }
+
+  /**
+   * Three-way comparison of two non-null operands: negative, zero or positive.
+   *
+   * NOTE(review): anything that is not a java.lang.Number, String or Boolean is compared through
+   * its toString form — confirm that this ordering is acceptable for every value type callers
+   * pass in (for example engine-internal string or decimal representations).
+   */
+  private def compareNonNull(left: Any, right: Any): Int = {
+    (left, right) match {
+      // Exact comparison for integral operands; a double comparison would lose precision for
+      // magnitudes above 2^53.
+      case (l: Number, r: Number) if isIntegral(l) && isIntegral(r) =>
+        java.lang.Long.compare(l.longValue(), r.longValue())
+      case (l: Number, r: Number) =>
+        val ld = l.doubleValue()
+        val rd = r.doubleValue()
+        if (ld < rd) -1 else if (ld > rd) 1 else 0
+      case (l: String, r: String) => l.compareTo(r)
+      case (l: Boolean, r: Boolean) => l.compareTo(r)
+      case _ =>
+        // Try to compare as strings as a fallback
+        left.toString.compareTo(right.toString)
+    }
+  }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 554794fb3b..ccfc6f83b4 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -1038,4 +1038,13 @@ object st_functions {
selfWeight,
useSpheroid,
attributes)
+
+  /**
+   * Builds a `Barrier` column from an expression column followed by the variable name/value
+   * columns.
+   *
+   * @param expression
+   *   column holding the boolean expression string
+   * @param args
+   *   alternating variable-name and variable-value columns
+   * @return
+   *   a column wrapping the `Barrier` expression
+   */
+  def barrier(expression: Column, args: Column*): Column =
+    wrapExpression[Barrier]((expression +: args): _*)
+  /**
+   * Builds a `Barrier` column from an expression string followed by the variable name/value
+   * arguments.
+   *
+   * NOTE(review): how `wrapExpression` interprets plain String arguments (string literal versus
+   * column name) is not visible here — confirm the expression reaches `Barrier` as a string.
+   *
+   * @param expression
+   *   the boolean expression string
+   * @param args
+   *   alternating variable names and variable values
+   * @return
+   *   a column wrapping the `Barrier` expression
+   */
+  def barrier(expression: String, args: Any*): Column =
+    wrapExpression[Barrier]((expression +: args): _*)
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index db22daa5e8..1e15051188 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -597,9 +597,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
SparkStrategy {
case None =>
Nil
}
- val objectSidePlan = if (querySide == LeftSide) right else left
-
- checkObjectPlanFilterPushdown(objectSidePlan)
logInfo(
"Planning knn join, left side is for queries and right size is for the
object to be searched")
@@ -737,10 +734,6 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends SparkStrategy {
case None =>
Nil
}
- val objectSidePlan = if (querySide == LeftSide) right else left
-
- checkObjectPlanFilterPushdown(objectSidePlan)
-
if (querySide == broadcastSide.get) {
// broadcast is on query side
return BroadcastQuerySideKNNJoinExec(
@@ -967,35 +960,4 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends SparkStrategy {
case other => other.children.exists(containPlanFilterPushdown)
}
}
-
- /**
- * Check if the given plan has a filter that can be pushed down to the
object side of the KNN
- * join. Print a warning if a filter pushdown is detected.
- * @param objectSidePlan
- */
- private def checkObjectPlanFilterPushdown(objectSidePlan: LogicalPlan): Unit
= {
- if (containPlanFilterPushdown(objectSidePlan)) {
- val warnings = Seq(
- "Warning: One or more filter pushdowns have been detected on the
object side of the KNN join. \n" +
- "These filters will be applied to the object side reader before the
KNN join is executed. \n" +
- "If you intend to apply the filters after the KNN join, please
ensure that you materialize the KNN join results before applying the filters.
\n" +
- "For example, you can use the following approach:\n\n" +
-
- // Scala Example
- "Scala Example:\n" +
- "val knnResult = knnJoinDF.cache()\n" +
- "val filteredResult = knnResult.filter(condition)\n\n" +
-
- // SQL Example
- "SQL Example:\n" +
- "CREATE OR REPLACE TEMP VIEW knnResult AS\n" +
- "SELECT * FROM (\n" +
- " -- Your KNN join SQL here\n" +
- ") AS knnView\n" +
- "CACHE TABLE knnResult;\n" +
- "SELECT * FROM knnResult WHERE condition;")
- logWarning(warnings.mkString("\n"))
- println(warnings.mkString("\n"))
- }
- }
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/BarrierFunctionTest.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/BarrierFunctionTest.scala
new file mode 100644
index 0000000000..d2c49fb717
--- /dev/null
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/BarrierFunctionTest.scala
@@ -0,0 +1,453 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType,
StructField, StructType, BooleanType}
+import org.apache.spark.sql.functions.expr
+import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, be}
+import org.scalatest.prop.TableDrivenPropertyChecks
+import scala.collection.JavaConverters._
+
+class BarrierFunctionTest extends TestBaseScala with TableDrivenPropertyChecks
{
+
+ describe("Barrier Function Test") {
+
+ it("should evaluate simple comparison expressions") {
+ // Create test data
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(Row(1, 4.5, 5), Row(2, 3.0, 3), Row(3, 5.0, 4), Row(4, 2.0,
2)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("rating", DoubleType, false),
+ StructField("stars", IntegerType, false))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ // Test simple greater than
+ val result1 =
+ sparkSession.sql("""SELECT id, barrier('rating > 4.0', 'rating',
rating) as result
+ FROM test_table""")
+ val expected1 = Seq(true, false, true, false)
+ result1.collect().map(_.getBoolean(1)) should be(expected1)
+
+ // Test AND condition
+ val result2 = sparkSession.sql("""SELECT id, barrier('rating > 4.0 AND
stars >= 4',
+ 'rating', rating,
+ 'stars', stars) as result
+ FROM test_table""")
+ val expected2 = Seq(true, false, true, false)
+ result2.collect().map(_.getBoolean(1)) should be(expected2)
+
+ // Test OR condition
+ val result3 = sparkSession.sql("""SELECT id, barrier('rating < 3.0 OR
stars >= 5',
+ 'rating', rating,
+ 'stars', stars) as result
+ FROM test_table""")
+ val expected3 = Seq(true, false, false, true)
+ result3.collect().map(_.getBoolean(1)) should be(expected3)
+ }
+
+ it("should handle different comparison operators") {
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(Row(1, 10, 10), Row(2, 20, 30), Row(3, 15, 15), Row(4, 25,
20)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("val1", IntegerType, false),
+ StructField("val2", IntegerType, false))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ // Test equals
+ val result1 = sparkSession.sql("""SELECT id, barrier('val1 = val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expected1 = Seq(true, false, true, false)
+ result1.collect().map(_.getBoolean(1)) should be(expected1)
+
+ // Test not equals
+ val result2 = sparkSession.sql("""SELECT id, barrier('val1 != val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expected2 = Seq(false, true, false, true)
+ result2.collect().map(_.getBoolean(1)) should be(expected2)
+
+ // Test less than or equal
+ val result3 = sparkSession.sql("""SELECT id, barrier('val1 <= val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expected3 = Seq(true, true, true, false)
+ result3.collect().map(_.getBoolean(1)) should be(expected3)
+ }
+
+ it("should handle NOT operator") {
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(Row(1, true), Row(2, false), Row(3, true), Row(4, false)).asJava,
+ StructType(
+ Seq(StructField("id", IntegerType, false), StructField("flag",
BooleanType, false))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ val result = sparkSession.sql("""SELECT id, barrier('NOT flag', 'flag',
flag) as result
+ FROM test_table""")
+ val expected = Seq(false, true, false, true)
+ result.collect().map(_.getBoolean(1)) should be(expected)
+ }
+
+ it("should handle parentheses for grouping") {
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, 10, 20, 30),
+ Row(2, 15, 25, 35),
+ Row(3, 5, 15, 25),
+ Row(4, 20, 10, 5)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("a", IntegerType, false),
+ StructField("b", IntegerType, false),
+ StructField("c", IntegerType, false))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ // Test with parentheses
+ val result =
+ sparkSession.sql("""SELECT id, barrier('(a < b AND b < c) OR (a > b
AND b > c)',
+ 'a', a, 'b', b, 'c', c) as result
+ FROM test_table""")
+ val expected = Seq(true, true, true, true)
+ result.collect().map(_.getBoolean(1)) should be(expected)
+ }
+
+ it("should handle string comparisons") {
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, "apple", "banana"),
+ Row(2, "zebra", "apple"),
+ Row(3, "cat", "cat"),
+ Row(4, "dog", "cat")).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("str1", StringType, false),
+ StructField("str2", StringType, false))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ val result = sparkSession.sql("""SELECT id, barrier('str1 < str2',
+ 'str1', str1,
+ 'str2', str2) as result
+ FROM test_table""")
+ val expected = Seq(true, false, false, false)
+ result.collect().map(_.getBoolean(1)) should be(expected)
+ }
+
+ it("should handle null values") {
+ val testDf = sparkSession
+ .createDataFrame(
+ Seq(Row(1, 10, 20), Row(2, null, 20), Row(3, 10, null), Row(4, null,
null)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("val1", IntegerType, true),
+ StructField("val2", IntegerType, true))))
+
+ testDf.createOrReplaceTempView("test_table")
+
+ // Test null equality comparisons
+ val resultEq = sparkSession.sql("""SELECT id, barrier('val1 = val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expectedEq = Seq(false, false, false, true)
+ resultEq.collect().map(_.getBoolean(1)) should be(expectedEq)
+
+ // Test null inequality comparisons
+ val resultNe = sparkSession.sql("""SELECT id, barrier('val1 != val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expectedNe = Seq(true, true, true, false)
+ resultNe.collect().map(_.getBoolean(1)) should be(expectedNe)
+
+ // Test null with <= operator
+ val resultLe = sparkSession.sql("""SELECT id, barrier('val1 <= val2',
+ 'val1', val1,
+ 'val2', val2) as result
+ FROM test_table""")
+ val expectedLe = Seq(true, false, false, true)
+ resultLe.collect().map(_.getBoolean(1)) should be(expectedLe)
+ }
+ }
+
+ describe("Barrier Function with KNN Spatial Join Tests") {
+    it("should prevent filter pushdown in KNN joins with complex conditions") {
+      // Create test data similar to KNN test patterns
+      val queries = sparkSession
+        .createDataFrame(
+          Seq(
+            Row(1, 1.0, 1.0, 5), // high rating query
+            Row(2, 2.0, 2.0, 3), // low rating query
+            Row(3, 3.0, 3.0, 4) // medium rating query
+          ).asJava,
+          StructType(
+            Seq(
+              StructField("q_id", IntegerType, false),
+              StructField("q_x", DoubleType, false),
+              StructField("q_y", DoubleType, false),
+              StructField("q_rating", IntegerType, false))))
+
+      val objects = sparkSession
+        .createDataFrame(
+          Seq(
+            Row(1, 0.9, 1.1, 2), // close to query 1, low rating
+            Row(2, 1.1, 0.9, 5), // close to query 1, high rating
+            Row(3, 2.1, 1.9, 3), // close to query 2, medium rating
+            Row(4, 2.9, 3.1, 4), // close to query 3, good rating
+            Row(5, 1.5, 1.5, 1), // between queries, very low rating
+            Row(6, 0.5, 0.5, 5) // close to query 1, high rating
+          ).asJava,
+          StructType(
+            Seq(
+              StructField("o_id", IntegerType, false),
+              StructField("o_x", DoubleType, false),
+              StructField("o_y", DoubleType, false),
+              StructField("o_rating", IntegerType, false))))
+
+      // Add geometry columns
+      val queriesWithGeom = queries.withColumn("q_geom", expr("ST_Point(q_x, q_y)"))
+      val objectsWithGeom = objects.withColumn("o_geom", expr("ST_Point(o_x, o_y)"))
+
+      queriesWithGeom.createOrReplaceTempView("test_queries")
+      objectsWithGeom.createOrReplaceTempView("test_objects")
+
+      // Test 1: KNN join with barrier to prevent early filtering
+      // This should find K=2 nearest neighbors first, then apply rating filter
+      val resultWithBarrier = sparkSession.sql("""
+        SELECT q.q_id, o.o_id, o.o_rating, q.q_rating
+        FROM test_queries q
+        JOIN test_objects o
+        ON ST_KNN(q.q_geom, o.o_geom, 2, false)
+        WHERE barrier('q_rating >= 4 AND o_rating >= 3',
+                      'q_rating', q.q_rating,
+                      'o_rating', o.o_rating)
+      """)
+
+      val matches = resultWithBarrier.collect().sortBy(r => (r.getInt(0), r.getInt(1)))
+
+      // Should find pairs where both query and object have good ratings
+      // from the 2 nearest neighbors per query
+      matches.length should be > 0
+      matches.foreach { row =>
+        val oRating = row.getInt(2) // o_rating is the third selected column
+        val qRating = row.getInt(3) // q_rating is the fourth selected column
+        oRating should be >= 3
+        qRating should be >= 4
+      }
+    }
+
+ it("should demonstrate different semantics with and without barrier in KNN
joins") {
+ // Create a scenario where barrier changes the result
+ val hotels = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, 0.0, 0.0, 5), // luxury hotel at origin
+ Row(2, 1.0, 1.0, 2) // budget hotel nearby
+ ).asJava,
+ StructType(
+ Seq(
+ StructField("h_id", IntegerType, false),
+ StructField("h_x", DoubleType, false),
+ StructField("h_y", DoubleType, false),
+ StructField("h_stars", IntegerType, false))))
+
+ val restaurants = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, 0.1, 0.1, 2.0), // close to luxury hotel, poor rating
+ Row(2, 0.2, 0.2, 3.0), // close to luxury hotel, ok rating
+ Row(3, 0.9, 0.9, 5.0), // close to budget hotel, excellent rating
+ Row(4, 5.0, 5.0, 5.0) // far away, excellent rating
+ ).asJava,
+ StructType(
+ Seq(
+ StructField("r_id", IntegerType, false),
+ StructField("r_x", DoubleType, false),
+ StructField("r_y", DoubleType, false),
+ StructField("r_rating", DoubleType, false))))
+
+ val hotelsWithGeom = hotels.withColumn("h_geom", expr("ST_Point(h_x,
h_y)"))
+ val restaurantsWithGeom = restaurants.withColumn("r_geom",
expr("ST_Point(r_x, r_y)"))
+
+ hotelsWithGeom.createOrReplaceTempView("knn_hotels")
+ restaurantsWithGeom.createOrReplaceTempView("knn_restaurants")
+
+ // Query 1: WITHOUT barrier - filter might get pushed down
+ // This could filter high-rated restaurants BEFORE finding nearest
+ val withoutBarrier = sparkSession.sql("""
+ SELECT h.h_id, r.r_id, r.r_rating
+ FROM knn_hotels h
+ JOIN knn_restaurants r
+ ON ST_KNN(h.h_geom, r.r_geom, 2, false)
+ WHERE h.h_stars >= 4 AND r.r_rating >= 4.0
+ """)
+
+ // Query 2: WITH barrier - ensures KNN completes before filtering
+ // This finds 2 nearest restaurants FIRST, then filters
+ val withBarrier = sparkSession.sql("""
+ SELECT h.h_id, r.r_id, r.r_rating
+ FROM knn_hotels h
+ JOIN knn_restaurants r
+ ON ST_KNN(h.h_geom, r.r_geom, 2, false)
+ WHERE barrier('h_stars >= 4 AND r_rating >= 4.0',
+ 'h_stars', h.h_stars,
+ 'r_rating', r.r_rating)
+ """)
+
+ val resultsWithoutBarrier = withoutBarrier.collect()
+ val resultsWithBarrier = withBarrier.collect()
+
+ // The key difference: barrier should return fewer/different results
+ resultsWithoutBarrier.length should be(2)
+
+ // Expected: withBarrier should find 0 results (nearest 1,2 don't meet
rating filter)
+ resultsWithBarrier.length should be(0)
+ }
+
+ it("should work with inequality conditions in KNN barrier functions") {
+ // Test adapted from KNN suite inequality tests
+ val points1 = sparkSession
+ .createDataFrame(
+ Seq(Row(1, 1.0, 1.0), Row(2, 2.0, 2.0), Row(3, 3.0, 3.0)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("x", DoubleType, false),
+ StructField("y", DoubleType, false))))
+
+ val points2 = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, 1.1, 1.1),
+ Row(3, 3.1, 3.1),
+ Row(6, 6.0, 6.0),
+ Row(13, 13.0, 13.0),
+ Row(16, 16.0, 16.0)).asJava,
+ StructType(
+ Seq(
+ StructField("id", IntegerType, false),
+ StructField("x", DoubleType, false),
+ StructField("y", DoubleType, false))))
+
+ val points1WithGeom = points1.withColumn("geom", expr("ST_Point(x, y)"))
+ val points2WithGeom = points2.withColumn("geom", expr("ST_Point(x, y)"))
+
+ points1WithGeom.createOrReplaceTempView("knn_points1")
+ points2WithGeom.createOrReplaceTempView("knn_points2")
+
+ // Test inequality condition with barrier
+ val result = sparkSession.sql("""
+ SELECT p1.id, p2.id
+ FROM knn_points1 p1
+ JOIN knn_points2 p2
+ ON ST_KNN(p1.geom, p2.geom, 3, false)
+ WHERE barrier('p1_id < p2_id',
+ 'p1_id', p1.id,
+ 'p2_id', p2.id)
+ """)
+
+ val matches = result.collect().sortBy(r => (r.getInt(0), r.getInt(1)))
+
+ // Should only include pairs where p1.id < p2.id
+ matches.foreach { row =>
+ val p1Id = row.getInt(0)
+ val p2Id = row.getInt(1)
+ p1Id should be < p2Id
+ }
+ }
+
+ it("should handle complex boolean expressions in KNN barrier scenarios") {
+ // Create test data for complex conditions
+ val venues = sparkSession
+ .createDataFrame(
+ Seq(
+ Row(1, 0.0, 0.0, "restaurant", 4.5, true),
+ Row(2, 0.1, 0.1, "cafe", 3.8, false),
+ Row(3, 0.2, 0.2, "restaurant", 4.2, true),
+ Row(4, 1.0, 1.0, "bar", 4.0, false)).asJava,
+ StructType(
+ Seq(
+ StructField("v_id", IntegerType, false),
+ StructField("v_x", DoubleType, false),
+ StructField("v_y", DoubleType, false),
+ StructField("v_type", StringType, false),
+ StructField("v_rating", DoubleType, false),
+ StructField("v_open_late", BooleanType, false))))
+
+ val users = sparkSession
+ .createDataFrame(
+ Seq(Row(1, 0.05, 0.05, 25, true), Row(2, 0.5, 0.5, 35,
false)).asJava,
+ StructType(
+ Seq(
+ StructField("u_id", IntegerType, false),
+ StructField("u_x", DoubleType, false),
+ StructField("u_y", DoubleType, false),
+ StructField("u_age", IntegerType, false),
+ StructField("u_night_owl", BooleanType, false))))
+
+ val venuesWithGeom = venues.withColumn("v_geom", expr("ST_Point(v_x,
v_y)"))
+ val usersWithGeom = users.withColumn("u_geom", expr("ST_Point(u_x,
u_y)"))
+
+ venuesWithGeom.createOrReplaceTempView("knn_venues")
+ usersWithGeom.createOrReplaceTempView("knn_users")
+
+ // Complex barrier condition: young night owls want open restaurants
with good ratings
+ val result = sparkSession.sql("""
+ SELECT u.u_id, v.v_id, v.v_type, v.v_rating
+ FROM knn_users u
+ JOIN knn_venues v
+ ON ST_KNN(u.u_geom, v.v_geom, 3, false)
+ WHERE barrier('(u_age < 30 AND u_night_owl AND v_open_late) AND
(v_type = "restaurant" AND v_rating > 4.0)',
+ 'u_age', u.u_age,
+ 'u_night_owl', u.u_night_owl,
+ 'v_open_late', v.v_open_late,
+ 'v_type', v.v_type,
+ 'v_rating', v.v_rating)
+ """)
+
+ val matches = result.collect()
+
+ // Should only match young night owls with open restaurants with good
ratings
+ matches.foreach { row =>
+ val vType = row.getString(2)
+ val vRating = row.getDouble(3)
+ vType should be("restaurant")
+ vRating should be > 4.0
+ }
+ }
+ }
+}