paleolimbot commented on code in PR #91:
URL: https://github.com/apache/sedona-db/pull/91#discussion_r2353127977


##########
rust/sedona-spatial-join/src/optimizer.rs:
##########
@@ -89,6 +90,123 @@ impl PhysicalOptimizerRule for SpatialJoinOptimizer {
     }
 }
 
+impl OptimizerRule for SpatialJoinOptimizer {
+    fn name(&self) -> &str {
+        "spatial_join_optimizer"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    /// Try to rewrite the plan containing a spatial Filter on top of a cross 
join without on or filter
+    /// to a theta-join with filter. For instance, the following query plan:
+    ///
+    /// ```text
+    /// Filter: st_intersects(l.geom, _scalar_sq_1.geom)
+    ///   Left Join (no on, no filter):
+    ///     TableScan: l projection=[id, geom]
+    ///     SubqueryAlias: __scalar_sq_1
+    ///       Projection: r.geom
+    ///         Filter: r.id = Int32(1)
+    ///           TableScan: r projection=[id, geom]
+    /// ```
+    ///
+    /// will be rewritten to
+    ///
+    /// ```text
+    /// Inner Join: Filter: st_intersects(l.geom, _scalar_sq_1.geom)
+    ///   TableScan: l projection=[id, geom]
+    ///   SubqueryAlias: __scalar_sq_1
+    ///     Projection: r.geom
+    ///       Filter: r.id = Int32(1)
+    ///         TableScan: r projection=[id, geom]
+    /// ```
+    ///
+    /// This is for enabling this logical join operator to be converted to a 
NestedLoopJoin physical
+    /// node with a spatial predicate, so that it could subsequently be 
optimized to a SpatialJoin
+    /// physical node. Please refer to the `PhysicalOptimizerRule` 
implementation of this struct
+    /// and [SpatialJoinOptimizer::try_optimize_join] for details.
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let LogicalPlan::Filter(Filter {
+            predicate, input, ..
+        }) = &plan
+        else {
+            return Ok(Transformed::no(plan));
+        };
+        if !is_spatial_predicate(predicate) {
+            return Ok(Transformed::no(plan));
+        }
+
+        let LogicalPlan::Join(Join {
+            ref left,
+            ref right,
+            ref on,
+            ref filter,
+            join_type,
+            ref join_constraint,
+            ref null_equality,
+            ..
+        }) = input.as_ref()
+        else {
+            return Ok(Transformed::no(plan));
+        };
+
+        // Check if this is a suitable join for rewriting
+        if !matches!(
+            join_type,
+            JoinType::Inner | JoinType::Left | JoinType::Right
+        ) || !on.is_empty()
+            || filter.is_some()
+        {
+            return Ok(Transformed::no(plan));
+        }
+
+        let rewritten_plan = Join::try_new(
+            Arc::clone(left),
+            Arc::clone(right),
+            on.clone(),
+            Some(predicate.clone()),
+            JoinType::Inner,
+            *join_constraint,
+            *null_equality,
+        )?;
+
+        Ok(Transformed::yes(LogicalPlan::Join(rewritten_plan)))
+    }
+}
+
+/// Check if a given logical expression contains a spatial predicate component 
or not. We assume that the given
+/// `expr` evaluates to a boolean value and originates from a filter logical 
node.
+fn is_spatial_predicate(expr: &Expr) -> bool {
+    fn is_distance_expr(expr: &Expr) -> bool {
+        let Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, 
.. }) = expr else {
+            return false;
+        };
+        func.name().to_lowercase() == "st_distance"
+    }
+
+    match expr {
+        Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
+            left, right, op, ..
+        }) => match op {
+            Operator::And => is_spatial_predicate(left) || 
is_spatial_predicate(right),
+            Operator::Lt | Operator::LtEq => is_distance_expr(left),
+            Operator::Gt | Operator::GtEq => is_distance_expr(right),
+            _ => false,
+        },
+        Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. 
}) => {
+            let func_name = func.name().to_lowercase();
+            func_name.starts_with("st_")

Review Comment:
   Either approach is OK... it just seems that either only a few whitelisted
functions are valid here, or *any* predicate that accepts a geometry as its
argument is valid. I can see how just filtering on the `st_` prefix is safer
than maintaining a whitelist of specific predicates.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to