pwrliang commented on code in PR #465:
URL: https://github.com/apache/sedona-db/pull/465#discussion_r2696571337


##########
rust/sedona-spatial-join/src/optimizer.rs:
##########
@@ -1054,6 +1080,282 @@ fn is_spatial_predicate_supported(
     }
 }
 
+// ============================================================================
+// GPU Optimizer Module
+// ============================================================================
+
+/// GPU optimizer module - conditionally compiled when GPU feature is enabled
+#[cfg(feature = "gpu")]
+mod gpu_optimizer {
+    use super::*;
+    use datafusion_common::DataFusionError;
+    use sedona_spatial_join_gpu::{
+        GeometryColumnInfo, GpuSpatialJoinConfig, GpuSpatialJoinExec, 
GpuSpatialPredicate,
+    };
+
+    /// Attempt to create a GPU-accelerated spatial join.
+    /// Returns None if GPU path is not applicable for this query.
+    pub fn try_create_gpu_spatial_join(
+        spatial_join: &SpatialJoinExec,
+        config: &ConfigOptions,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        let sedona_options = config
+            .extensions
+            .get::<SedonaOptions>()
+            .ok_or_else(|| DataFusionError::Internal("SedonaOptions not 
found".into()))?;
+
+        // Check if GPU is enabled
+        if !sedona_options.spatial_join.gpu.enable {
+            return Ok(None);
+        }
+
+        // Check if predicate is supported on GPU
+        if !is_gpu_supported_predicate(&spatial_join.on) {
+            log::debug!("Predicate {:?} not supported on GPU", 
spatial_join.on);
+            return Ok(None);
+        }
+
+        // Get child plans
+        let left = spatial_join.left.clone();
+        let right = spatial_join.right.clone();
+
+        // Get schemas from child plans
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+
+        // Find geometry columns in schemas
+        let left_geom_col = find_geometry_column(&left_schema)?;
+        let right_geom_col = find_geometry_column(&right_schema)?;
+
+        // Convert spatial predicate to GPU predicate
+        let gpu_predicate = convert_to_gpu_predicate(&spatial_join.on)?;
+
+        // Create GPU spatial join configuration
+        let gpu_config = GpuSpatialJoinConfig {
+            join_type: *spatial_join.join_type(),
+            left_geom_column: left_geom_col,
+            right_geom_column: right_geom_col,
+            predicate: gpu_predicate,
+            device_id: sedona_options.spatial_join.gpu.device_id as i32,
+            batch_size: sedona_options.spatial_join.gpu.batch_size,
+            additional_filters: spatial_join.filter.clone(),
+            max_memory: if sedona_options.spatial_join.gpu.max_memory_mb > 0 {
+                Some(sedona_options.spatial_join.gpu.max_memory_mb * 1024 * 
1024)
+            } else {
+                None
+            },
+            fallback_to_cpu: sedona_options.spatial_join.gpu.fallback_to_cpu,
+        };
+
+        log::info!(
+            "Creating GPU spatial join: predicate: {:?}, left geom: {}, right 
geom: {}",
+            gpu_config.predicate,
+            gpu_config.left_geom_column.name,
+            gpu_config.right_geom_column.name,
+        );
+
+        let gpu_join = Arc::new(GpuSpatialJoinExec::new(left, right, 
gpu_config)?);
+
+        // If the original SpatialJoinExec had a projection, wrap the GPU join 
with a ProjectionExec
+        if spatial_join.contains_projection() {
+            use datafusion_physical_expr::expressions::Column;
+            use datafusion_physical_plan::projection::ProjectionExec;
+
+            // Get the projection indices from the SpatialJoinExec
+            let projection_indices = spatial_join
+                .projection()
+                .expect("contains_projection() was true but projection() 
returned None");
+
+            // Create projection expressions that map from GPU join output to 
desired output
+            let mut projection_exprs = Vec::new();
+            let gpu_schema = gpu_join.schema();
+
+            for &idx in projection_indices {
+                let field = gpu_schema.field(idx);
+                let col_expr = Arc::new(Column::new(field.name(), idx))
+                    as Arc<dyn datafusion_physical_expr::PhysicalExpr>;
+                projection_exprs.push((col_expr, field.name().clone()));
+            }
+
+            let projection_exec = ProjectionExec::try_new(projection_exprs, 
gpu_join)?;
+            Ok(Some(Arc::new(projection_exec)))
+        } else {
+            Ok(Some(gpu_join))
+        }
+    }
+
+    /// Check if spatial predicate is supported on GPU
+    pub(crate) fn is_gpu_supported_predicate(predicate: &SpatialPredicate) -> 
bool {
+        match predicate {
+            SpatialPredicate::Relation(rel) => {
+                use crate::spatial_predicate::SpatialRelationType;
+                matches!(
+                    rel.relation_type,
+                    SpatialRelationType::Intersects
+                        | SpatialRelationType::Contains
+                        | SpatialRelationType::Covers
+                        | SpatialRelationType::Within
+                        | SpatialRelationType::CoveredBy
+                        | SpatialRelationType::Touches
+                        | SpatialRelationType::Equals
+                )
+            }
+            // Distance predicates not yet supported on GPU
+            SpatialPredicate::Distance(_) => false,
+            // KNN not yet supported on GPU
+            SpatialPredicate::KNearestNeighbors(_) => false,
+        }
+    }
+
+    /// Find geometry column in schema
+    pub(crate) fn find_geometry_column(schema: &SchemaRef) -> 
Result<GeometryColumnInfo> {
+        use arrow_schema::DataType;
+
+        for (idx, field) in schema.fields().iter().enumerate() {
+            // Check if this is a WKB geometry column (Binary, LargeBinary, or 
BinaryView)
+            if matches!(
+                field.data_type(),
+                DataType::Binary | DataType::LargeBinary | DataType::BinaryView
+            ) {
+                // Check metadata for geometry type
+                if let Some(meta) = 
field.metadata().get("ARROW:extension:name") {
+                    if meta.contains("geoarrow.wkb") || 
meta.contains("geometry") {
+                        return Ok(GeometryColumnInfo {
+                            name: field.name().clone(),
+                            index: idx,
+                        });
+                    }
+                }
+
+                // If no metadata, assume the last binary column is geometry
+                // This is a fallback for files without proper GeoArrow 
metadata
+                if idx == schema.fields().len() - 1
+                    || schema.fields().iter().skip(idx + 1).all(|f| {
+                        !matches!(
+                            f.data_type(),
+                            DataType::Binary | DataType::LargeBinary | 
DataType::BinaryView
+                        )
+                    })
+                {
+                    log::warn!(
+                        "Geometry column '{}' has no GeoArrow metadata, 
assuming it's WKB",
+                        field.name()
+                    );
+                    return Ok(GeometryColumnInfo {
+                        name: field.name().clone(),
+                        index: idx,
+                    });
+                }
+            }
+        }
+
+        Err(DataFusionError::Plan(
+            "No geometry column found in schema".into(),
+        ))
+    }
+
+    /// Convert SpatialPredicate to GPU predicate
+    pub(crate) fn convert_to_gpu_predicate(
+        predicate: &SpatialPredicate,
+    ) -> Result<GpuSpatialPredicate> {
+        use crate::spatial_predicate::SpatialRelationType;
+        use sedona_libgpuspatial::SpatialPredicate as LibGpuPred;
+
+        match predicate {
+            SpatialPredicate::Relation(rel) => {
+                let gpu_pred = match rel.relation_type {
+                    SpatialRelationType::Intersects => LibGpuPred::Intersects,
+                    SpatialRelationType::Contains => LibGpuPred::Contains,
+                    SpatialRelationType::Covers => LibGpuPred::Covers,
+                    SpatialRelationType::Within => LibGpuPred::Within,
+                    SpatialRelationType::CoveredBy => LibGpuPred::CoveredBy,
+                    SpatialRelationType::Touches => LibGpuPred::Touches,
+                    SpatialRelationType::Equals => LibGpuPred::Equals,

Review Comment:
   I have extracted SpatialRelationType into a separate file under 
sedona-geometry.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to