This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 5102eedc feat(rust/sedona-spatial-join) Add partitioned index provider 
(#555)
5102eedc is described below

commit 5102eedc6f9f48d68f49208794ccafb5eceb0213
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jan 29 22:17:46 2026 +0800

    feat(rust/sedona-spatial-join) Add partitioned index provider (#555)
    
    This patch adds an index provider for coordinating the creation of spatial 
indexes for specified partitions. It is also integrated into `SpatialJoinExec` so 
that we use it to create the spatial index even when there's only one spatial 
partition (the degenerate case). The handling for multiple spatial partitions 
will be added in a subsequent PR.
    
    The memory reservations grown in the build side collection phase will be 
held by `PartitionedIndexProvider`. Spatial indexes created by the provider 
do not need to hold memory reservations.
    
    The next step is to support a partitioned probe side by adding a 
`PartitionedProbeStreamProvider`, and modifying the state machine of 
`SpatialJoinStream` to process multiple spatial partitions sequentially.
---
 rust/sedona-spatial-join/Cargo.toml                |   1 +
 rust/sedona-spatial-join/src/build_index.rs        |   1 -
 rust/sedona-spatial-join/src/exec.rs               |  50 +-
 rust/sedona-spatial-join/src/index.rs              |   2 +
 .../src/index/build_side_collector.rs              |  29 +-
 rust/sedona-spatial-join/src/index/memory_plan.rs  | 191 +++++++
 .../src/index/partitioned_index_provider.rs        | 602 +++++++++++++++++++++
 .../sedona-spatial-join/src/index/spatial_index.rs |  35 --
 .../src/index/spatial_index_builder.rs             |  38 +-
 rust/sedona-spatial-join/src/lib.rs                |   3 +-
 rust/sedona-spatial-join/src/partitioning/kdb.rs   |  35 +-
 .../src/partitioning/stream_repartitioner.rs       |   7 +
 rust/sedona-spatial-join/src/prepare.rs            | 514 ++++++++++++++++++
 rust/sedona-spatial-join/src/stream.rs             | 135 ++++-
 rust/sedona-spatial-join/src/utils.rs              |   1 +
 rust/sedona-spatial-join/src/utils/bbox_sampler.rs |   1 -
 .../src/utils/disposable_async_cell.rs             | 204 +++++++
 17 files changed, 1734 insertions(+), 115 deletions(-)

diff --git a/rust/sedona-spatial-join/Cargo.toml 
b/rust/sedona-spatial-join/Cargo.toml
index 322ec572..d34f7a6c 100644
--- a/rust/sedona-spatial-join/Cargo.toml
+++ b/rust/sedona-spatial-join/Cargo.toml
@@ -48,6 +48,7 @@ futures = { workspace = true }
 pin-project-lite = { workspace = true }
 once_cell = { workspace = true }
 parking_lot = { workspace = true }
+tokio = { workspace = true }
 geo = { workspace = true }
 sedona-geo-generic-alg = { workspace = true }
 geo-traits = { workspace = true, features = ["geo-types"] }
diff --git a/rust/sedona-spatial-join/src/build_index.rs 
b/rust/sedona-spatial-join/src/build_index.rs
index f3cbb34b..0e292369 100644
--- a/rust/sedona-spatial-join/src/build_index.rs
+++ b/rust/sedona-spatial-join/src/build_index.rs
@@ -105,7 +105,6 @@ pub async fn build_index(
             sedona_options.spatial_join,
             join_type,
             probe_threads_count,
-            Arc::clone(memory_pool),
             SpatialJoinBuildMetrics::new(0, &metrics),
         )?;
         index_builder.add_partitions(build_partitions).await?;
diff --git a/rust/sedona-spatial-join/src/exec.rs 
b/rust/sedona-spatial-join/src/exec.rs
index 50cbd171..495518ea 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -36,12 +36,13 @@ use parking_lot::Mutex;
 use sedona_common::SpatialJoinOptions;
 
 use crate::{
-    build_index::build_index,
-    index::SpatialIndex,
+    prepare::{SpatialJoinComponents, SpatialJoinComponentsBuilder},
     spatial_predicate::{KNNPredicate, SpatialPredicate},
     stream::{SpatialJoinProbeMetrics, SpatialJoinStream},
-    utils::join_utils::{asymmetric_join_output_partitioning, 
boundedness_from_children},
-    utils::once_fut::OnceAsync,
+    utils::{
+        join_utils::{asymmetric_join_output_partitioning, 
boundedness_from_children},
+        once_fut::OnceAsync,
+    },
     SedonaOptions,
 };
 
@@ -132,9 +133,10 @@ pub struct SpatialJoinExec {
     column_indices: Vec<ColumnIndex>,
     /// Cache holding plan properties like equivalences, output partitioning 
etc.
     cache: PlanProperties,
-    /// Spatial index built asynchronously on first execute() call and shared 
across all partitions.
-    /// Uses OnceAsync for lazy initialization coordinated via async runtime.
-    once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+    /// Once future for creating the partitioned index provider shared by all 
probe partitions.
+    /// This future runs only once before probing starts, and can be disposed 
by the last finished
+    /// stream so the provider does not outlive the execution plan 
unnecessarily.
+    once_async_spatial_join_components: 
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
     /// Indicates if this SpatialJoin was converted from a HashJoin
     /// When true, we preserve HashJoin's equivalence properties and 
partitioning
     converted_from_hash_join: bool,
@@ -203,7 +205,7 @@ impl SpatialJoinExec {
             projection,
             metrics: Default::default(),
             cache,
-            once_async_spatial_index: Arc::new(Mutex::new(None)),
+            once_async_spatial_join_components: Arc::new(Mutex::new(None)),
             converted_from_hash_join,
             seed,
         })
@@ -431,7 +433,7 @@ impl ExecutionPlan for SpatialJoinExec {
             projection: self.projection.clone(),
             metrics: Default::default(),
             cache: self.cache.clone(),
-            once_async_spatial_index: Arc::new(Mutex::new(None)),
+            once_async_spatial_join_components: Arc::new(Mutex::new(None)),
             converted_from_hash_join: self.converted_from_hash_join,
             seed: self.seed,
         }))
@@ -463,8 +465,8 @@ impl ExecutionPlan for SpatialJoinExec {
                 let (build_plan, probe_plan) = (&self.left, &self.right);
 
                 // Build the spatial index using shared OnceAsync
-                let once_fut_spatial_index = {
-                    let mut once_async = self.once_async_spatial_index.lock();
+                let once_fut_spatial_join_components = {
+                    let mut once_async = 
self.once_async_spatial_join_components.lock();
                     once_async
                         .get_or_insert(OnceAsync::default())
                         .try_once(|| {
@@ -479,16 +481,16 @@ impl ExecutionPlan for SpatialJoinExec {
 
                             let probe_thread_count =
                                 
self.right.output_partitioning().partition_count();
-                            Ok(build_index(
+                            let spatial_join_components_builder = 
SpatialJoinComponentsBuilder::new(
                                 Arc::clone(&context),
                                 build_side.schema(),
-                                build_streams,
                                 self.on.clone(),
                                 self.join_type,
                                 probe_thread_count,
                                 self.metrics.clone(),
                                 self.seed,
-                            ))
+                            );
+                            
Ok(spatial_join_components_builder.build(build_streams))
                         })?
                 };
 
@@ -508,6 +510,7 @@ impl ExecutionPlan for SpatialJoinExec {
                     self.maintains_input_order()[1] && 
self.right.output_ordering().is_some();
 
                 Ok(Box::pin(SpatialJoinStream::new(
+                    partition,
                     self.schema(),
                     &self.on,
                     self.filter.clone(),
@@ -518,8 +521,8 @@ impl ExecutionPlan for SpatialJoinExec {
                     join_metrics,
                     sedona_options.spatial_join,
                     target_output_batch_size,
-                    once_fut_spatial_index,
-                    Arc::clone(&self.once_async_spatial_index),
+                    once_fut_spatial_join_components,
+                    Arc::clone(&self.once_async_spatial_join_components),
                 )))
             }
         }
@@ -556,8 +559,8 @@ impl SpatialJoinExec {
         let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), 
self.left.as_ref());
 
         // Build the spatial index
-        let once_fut_spatial_index = {
-            let mut once_async = self.once_async_spatial_index.lock();
+        let once_fut_spatial_join_components = {
+            let mut once_async = 
self.once_async_spatial_join_components.lock();
             once_async
                 .get_or_insert(OnceAsync::default())
                 .try_once(|| {
@@ -571,16 +574,16 @@ impl SpatialJoinExec {
                     }
 
                     let probe_thread_count = 
probe_plan.output_partitioning().partition_count();
-                    Ok(build_index(
+                    let spatial_join_components_builder = 
SpatialJoinComponentsBuilder::new(
                         Arc::clone(&context),
                         build_side.schema(),
-                        build_streams,
                         self.on.clone(),
                         self.join_type,
                         probe_thread_count,
                         self.metrics.clone(),
                         self.seed,
-                    ))
+                    );
+                    Ok(spatial_join_components_builder.build(build_streams))
                 })?
         };
 
@@ -605,6 +608,7 @@ impl SpatialJoinExec {
         };
 
         Ok(Box::pin(SpatialJoinStream::new(
+            partition,
             self.schema(),
             &self.on,
             self.filter.clone(),
@@ -615,8 +619,8 @@ impl SpatialJoinExec {
             join_metrics,
             sedona_options.spatial_join,
             target_output_batch_size,
-            once_fut_spatial_index,
-            Arc::clone(&self.once_async_spatial_index),
+            once_fut_spatial_join_components,
+            Arc::clone(&self.once_async_spatial_join_components),
         )))
     }
 }
diff --git a/rust/sedona-spatial-join/src/index.rs 
b/rust/sedona-spatial-join/src/index.rs
index 55df23d5..af31b8af 100644
--- a/rust/sedona-spatial-join/src/index.rs
+++ b/rust/sedona-spatial-join/src/index.rs
@@ -17,6 +17,8 @@
 
 pub(crate) mod build_side_collector;
 mod knn_adapter;
+pub(crate) mod memory_plan;
+pub(crate) mod partitioned_index_provider;
 pub(crate) mod spatial_index;
 pub(crate) mod spatial_index_builder;
 
diff --git a/rust/sedona-spatial-join/src/index/build_side_collector.rs 
b/rust/sedona-spatial-join/src/index/build_side_collector.rs
index 646c6be2..d888680f 100644
--- a/rust/sedona-spatial-join/src/index/build_side_collector.rs
+++ b/rust/sedona-spatial-join/src/index/build_side_collector.rs
@@ -68,6 +68,9 @@ pub(crate) struct BuildPartition {
     /// The size of this reservation will be used to determine the maximum 
size of
     /// each spatial partition, as well as how many spatial partitions to 
create.
     pub reservation: MemoryReservation,
+
+    /// Metrics collected during the build side collection phase
+    pub metrics: CollectBuildSideMetrics,
 }
 
 /// A collector for evaluating the spatial expression on build side batches 
and collect
@@ -112,6 +115,10 @@ impl CollectBuildSideMetrics {
             spill_metrics: SpillMetrics::new(metrics, partition),
         }
     }
+
+    pub fn spill_metrics(&self) -> SpillMetrics {
+        self.spill_metrics.clone()
+    }
 }
 
 impl BuildSideBatchesCollector {
@@ -147,7 +154,7 @@ impl BuildSideBatchesCollector {
         mut stream: SendableEvaluatedBatchStream,
         mut reservation: MemoryReservation,
         mut bbox_sampler: BoundingBoxSampler,
-        metrics: &CollectBuildSideMetrics,
+        metrics: CollectBuildSideMetrics,
     ) -> Result<BuildPartition> {
         let mut spill_writer_opt = None;
         let mut in_mem_batches: Vec<EvaluatedBatch> = Vec::new();
@@ -200,7 +207,7 @@ impl BuildSideBatchesCollector {
                             e,
                         );
                         spill_writer_opt =
-                            self.spill_in_mem_batches(&mut in_mem_batches, 
metrics)?;
+                            self.spill_in_mem_batches(&mut in_mem_batches, 
&metrics)?;
                     }
                 }
                 Some(spill_writer) => {
@@ -236,7 +243,7 @@ impl BuildSideBatchesCollector {
                 "Force spilling enabled. Spilling {} in-memory batches to 
disk.",
                 in_mem_batches.len()
             );
-            spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, 
metrics)?;
+            spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, 
&metrics)?;
         }
 
         let build_side_batch_stream: SendableEvaluatedBatchStream = match 
spill_writer_opt {
@@ -266,6 +273,7 @@ impl BuildSideBatchesCollector {
             bbox_samples: bbox_sampler.into_samples(),
             estimated_spatial_index_memory_usage,
             reservation,
+            metrics,
         })
     }
 
@@ -329,7 +337,7 @@ impl BuildSideBatchesCollector {
                 let evaluated_stream =
                     create_evaluated_build_stream(stream, evaluator, 
metrics.time_taken.clone());
                 let result = collector
-                    .collect(evaluated_stream, reservation, bbox_sampler, 
&metrics)
+                    .collect(evaluated_stream, reservation, bbox_sampler, 
metrics)
                     .await;
                 (partition_id, result)
             });
@@ -378,7 +386,7 @@ impl BuildSideBatchesCollector {
             let evaluated_stream =
                 create_evaluated_build_stream(stream, evaluator, 
metrics.time_taken.clone());
             let result = self
-                .collect(evaluated_stream, reservation, bbox_sampler, &metrics)
+                .collect(evaluated_stream, reservation, bbox_sampler, metrics)
                 .await?;
             results.push(result);
         }
@@ -534,11 +542,12 @@ mod tests {
         let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
 
         let partition = collector
-            .collect(stream, reservation, sampler, &metrics)
+            .collect(stream, reservation, sampler, metrics)
             .await?;
         let stream = partition.build_side_batch_stream;
         let is_external = stream.is_external();
         let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+        let metrics = &partition.metrics;
         assert!(!is_external, "Expected in-memory batches");
         assert_eq!(collect_ids(&batches), vec![0, 1, 2]);
         assert_eq!(partition.num_rows, 3);
@@ -564,14 +573,15 @@ mod tests {
         let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
 
         let partition = collector
-            .collect(stream, reservation, sampler, &metrics)
+            .collect(stream, reservation, sampler, metrics)
             .await?;
         let stream = partition.build_side_batch_stream;
         let is_external = stream.is_external();
         let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+        let metrics = &partition.metrics;
         assert!(is_external, "Expected batches to spill to disk");
         assert_eq!(collect_ids(&batches), vec![10, 11, 12]);
-        let spill_metrics = metrics.spill_metrics;
+        let spill_metrics = metrics.spill_metrics();
         assert!(spill_metrics.spill_file_count.value() >= 1);
         assert!(spill_metrics.spilled_rows.value() >= 1);
         Ok(())
@@ -587,12 +597,13 @@ mod tests {
         let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
 
         let partition = collector
-            .collect(stream, reservation, sampler, &metrics)
+            .collect(stream, reservation, sampler, metrics)
             .await?;
         assert_eq!(partition.num_rows, 0);
         let stream = partition.build_side_batch_stream;
         let is_external = stream.is_external();
         let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+        let metrics = &partition.metrics;
         assert!(!is_external);
         assert!(batches.is_empty());
         assert_eq!(metrics.num_batches.value(), 0);
diff --git a/rust/sedona-spatial-join/src/index/memory_plan.rs 
b/rust/sedona-spatial-join/src/index/memory_plan.rs
new file mode 100644
index 00000000..24a25c89
--- /dev/null
+++ b/rust/sedona-spatial-join/src/index/memory_plan.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::max;
+
+use datafusion_common::{DataFusionError, Result};
+
+use super::BuildPartition;
+
+/// The memory accounting summary of a build side partition. This is collected
+/// during the build side collection phase and used to estimate the memory 
usage for
+/// running spatial join.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) struct PartitionMemorySummary {
+    /// Number of rows in the partition.
+    pub num_rows: usize,
+    /// The total memory reserved when collecting this build side partition.
+    pub reserved_memory: usize,
+    /// The estimated memory usage for building the spatial index for all the 
data in
+    /// this build side partition.
+    pub estimated_index_memory_usage: usize,
+}
+
+impl From<&BuildPartition> for PartitionMemorySummary {
+    fn from(partition: &BuildPartition) -> Self {
+        Self {
+            num_rows: partition.num_rows,
+            reserved_memory: partition.reservation.size(),
+            estimated_index_memory_usage: 
partition.estimated_spatial_index_memory_usage,
+        }
+    }
+}
+
+/// A detailed plan for memory usage during spatial join execution. The 
spatial join
+/// could be spatial-partitioned if the reserved memory is not sufficient to 
hold the
+/// entire spatial index.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) struct MemoryPlan {
+    /// The total number of rows in the build side.
+    pub num_rows: usize,
+    /// The total memory reserved for the build side.
+    pub reserved_memory: usize,
+    /// The estimated memory usage for building the spatial index for the 
entire build side.
+    /// It could be larger than [`Self::reserved_memory`], and in that case we 
need to
+    /// partition the build side using spatial partitioning.
+    pub estimated_index_memory_usage: usize,
+    /// The memory budget for holding the spatial index. If the spatial join 
is partitioned,
+    /// this is the memory budget for holding the spatial index of a single 
partition.
+    pub memory_for_spatial_index: usize,
+    /// The memory budget for intermittent usage, such as buffering data 
during repartitioning.
+    pub memory_for_intermittent_usage: usize,
+    /// The number of spatial partitions to split the build side into.
+    pub num_partitions: usize,
+}
+
+/// Compute the memory plan for running spatial join based on the memory 
summaries of
+/// build side partitions.
+pub(crate) fn compute_memory_plan<I>(partition_summaries: I) -> 
Result<MemoryPlan>
+where
+    I: IntoIterator<Item = PartitionMemorySummary>,
+{
+    let mut num_rows = 0;
+    let mut reserved_memory = 0;
+    let mut estimated_index_memory_usage = 0;
+
+    for summary in partition_summaries {
+        num_rows += summary.num_rows;
+        reserved_memory += summary.reserved_memory;
+        estimated_index_memory_usage += summary.estimated_index_memory_usage;
+    }
+
+    if reserved_memory == 0 && num_rows > 0 {
+        return Err(DataFusionError::ResourcesExhausted(
+            "Insufficient memory for spatial join".to_string(),
+        ));
+    }
+
+    // Use 80% of reserved memory for holding the spatial index. The other 20% 
are reserved for
+    // intermittent usage like repartitioning buffers.
+    let memory_for_spatial_index =
+        calculate_memory_for_spatial_index(reserved_memory, 
estimated_index_memory_usage);
+    let memory_for_intermittent_usage = reserved_memory - 
memory_for_spatial_index;
+
+    let num_partitions = if num_rows > 0 {
+        max(
+            1,
+            estimated_index_memory_usage.div_ceil(memory_for_spatial_index),
+        )
+    } else {
+        1
+    };
+
+    Ok(MemoryPlan {
+        num_rows,
+        reserved_memory,
+        estimated_index_memory_usage,
+        memory_for_spatial_index,
+        memory_for_intermittent_usage,
+        num_partitions,
+    })
+}
+
+fn calculate_memory_for_spatial_index(
+    reserved_memory: usize,
+    estimated_index_memory_usage: usize,
+) -> usize {
+    if reserved_memory >= estimated_index_memory_usage {
+        // Reserved memory is sufficient to hold the entire spatial index. 
Make sure that
+        // the memory for spatial index is enough for holding the entire 
index. The rest
+        // can be used for intermittent usage.
+        estimated_index_memory_usage
+    } else {
+        // Reserved memory is not sufficient to hold the entire spatial index, 
We need to
+        // partition the dataset using spatial partitioning. Use 80% of 
reserved memory
+        // for holding the partitioned spatial index. The rest is used for 
intermittent usage.
+        let reserved_portion = reserved_memory.saturating_mul(80) / 100;
+        if reserved_portion == 0 {
+            reserved_memory
+        } else {
+            reserved_portion
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn summary(
+        num_rows: usize,
+        reserved_memory: usize,
+        estimated_usage: usize,
+    ) -> PartitionMemorySummary {
+        PartitionMemorySummary {
+            num_rows,
+            reserved_memory,
+            estimated_index_memory_usage: estimated_usage,
+        }
+    }
+
+    #[test]
+    fn memory_plan_errors_when_no_memory_but_rows_exist() {
+        let err = compute_memory_plan(vec![summary(10, 0, 512)]).unwrap_err();
+        assert!(matches!(
+            err,
+            DataFusionError::ResourcesExhausted(msg) if 
msg.contains("Insufficient memory")
+        ));
+    }
+
+    #[test]
+    fn memory_plan_partitions_large_jobs() {
+        let plan =
+            compute_memory_plan(vec![summary(100, 2_000, 1_500), summary(150, 
1_000, 3_500)])
+                .expect("plan should succeed");
+
+        assert_eq!(plan.num_rows, 250);
+        assert_eq!(plan.reserved_memory, 3_000);
+        assert_eq!(plan.memory_for_spatial_index, 2_400);
+        assert_eq!(plan.memory_for_intermittent_usage, 600);
+        assert_eq!(plan.num_partitions, 3);
+    }
+
+    #[test]
+    fn memory_plan_handles_zero_rows() {
+        let plan = compute_memory_plan(vec![summary(0, 0, 0)]).expect("plan 
should succeed");
+        assert_eq!(plan.num_partitions, 1);
+        assert_eq!(plan.memory_for_spatial_index, 0);
+        assert_eq!(plan.memory_for_intermittent_usage, 0);
+    }
+
+    #[test]
+    fn memory_plan_uses_entire_reservation_when_fraction_rounds_down() {
+        let plan = compute_memory_plan(vec![summary(10, 1, 1)]).expect("plan 
should succeed");
+        assert_eq!(plan.memory_for_spatial_index, 1);
+        assert_eq!(plan.memory_for_intermittent_usage, 0);
+    }
+}
diff --git a/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs 
b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs
new file mode 100644
index 00000000..f9aeb893
--- /dev/null
+++ b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs
@@ -0,0 +1,602 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::SchemaRef;
+use datafusion_common::{DataFusionError, Result, SharedResult};
+use datafusion_common_runtime::JoinSet;
+use datafusion_execution::memory_pool::MemoryReservation;
+use datafusion_expr::JoinType;
+use futures::StreamExt;
+use parking_lot::Mutex;
+use sedona_common::{sedona_internal_err, SpatialJoinOptions};
+use std::ops::DerefMut;
+use std::sync::Arc;
+use tokio::sync::mpsc;
+
+use 
crate::evaluated_batch::evaluated_batch_stream::external::ExternalEvaluatedBatchStream;
+use crate::index::BuildPartition;
+use crate::partitioning::stream_repartitioner::{SpilledPartition, 
SpilledPartitions};
+use crate::utils::disposable_async_cell::DisposableAsyncCell;
+use crate::{
+    index::{SpatialIndex, SpatialIndexBuilder, SpatialJoinBuildMetrics},
+    partitioning::SpatialPartition,
+    spatial_predicate::SpatialPredicate,
+};
+
+pub(crate) struct PartitionedIndexProvider {
+    schema: SchemaRef,
+    spatial_predicate: SpatialPredicate,
+    options: SpatialJoinOptions,
+    join_type: JoinType,
+    probe_threads_count: usize,
+    metrics: SpatialJoinBuildMetrics,
+
+    /// Data on the build side to build index for
+    data: BuildSideData,
+
+    /// Async cells for indexes, one per regular partition
+    index_cells: Vec<DisposableAsyncCell<SharedResult<Arc<SpatialIndex>>>>,
+
+    /// The memory reserved in the build side collection phase. We'll hold 
them until
+    /// we don't need to build spatial indexes.
+    _reservations: Vec<MemoryReservation>,
+}
+
+pub(crate) enum BuildSideData {
+    SinglePartition(Mutex<Option<Vec<BuildPartition>>>),
+    MultiPartition(Mutex<SpilledPartitions>),
+}
+
+impl PartitionedIndexProvider {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new_multi_partition(
+        schema: SchemaRef,
+        spatial_predicate: SpatialPredicate,
+        options: SpatialJoinOptions,
+        join_type: JoinType,
+        probe_threads_count: usize,
+        partitioned_spill_files: SpilledPartitions,
+        metrics: SpatialJoinBuildMetrics,
+        reservations: Vec<MemoryReservation>,
+    ) -> Self {
+        let num_partitions = partitioned_spill_files.num_regular_partitions();
+        let index_cells = (0..num_partitions)
+            .map(|_| DisposableAsyncCell::new())
+            .collect();
+        Self {
+            schema,
+            spatial_predicate,
+            options,
+            join_type,
+            probe_threads_count,
+            metrics,
+            data: 
BuildSideData::MultiPartition(Mutex::new(partitioned_spill_files)),
+            index_cells,
+            _reservations: reservations,
+        }
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub fn new_single_partition(
+        schema: SchemaRef,
+        spatial_predicate: SpatialPredicate,
+        options: SpatialJoinOptions,
+        join_type: JoinType,
+        probe_threads_count: usize,
+        mut build_partitions: Vec<BuildPartition>,
+        metrics: SpatialJoinBuildMetrics,
+    ) -> Self {
+        let reservations = build_partitions
+            .iter_mut()
+            .map(|p| p.reservation.take())
+            .collect();
+        let index_cells = vec![DisposableAsyncCell::new()];
+        Self {
+            schema,
+            spatial_predicate,
+            options,
+            join_type,
+            probe_threads_count,
+            metrics,
+            data: 
BuildSideData::SinglePartition(Mutex::new(Some(build_partitions))),
+            index_cells,
+            _reservations: reservations,
+        }
+    }
+
+    pub fn new_empty(
+        schema: SchemaRef,
+        spatial_predicate: SpatialPredicate,
+        options: SpatialJoinOptions,
+        join_type: JoinType,
+        probe_threads_count: usize,
+        metrics: SpatialJoinBuildMetrics,
+    ) -> Self {
+        let build_partitions = Vec::new();
+        Self::new_single_partition(
+            schema,
+            spatial_predicate,
+            options,
+            join_type,
+            probe_threads_count,
+            build_partitions,
+            metrics,
+        )
+    }
+
+    pub fn num_regular_partitions(&self) -> usize {
+        self.index_cells.len()
+    }
+
+    pub async fn build_or_wait_for_index(
+        &self,
+        partition_id: u32,
+    ) -> Option<Result<Arc<SpatialIndex>>> {
+        let cell = match self.index_cells.get(partition_id as usize) {
+            Some(cell) => cell,
+            None => {
+                return Some(sedona_internal_err!(
+                    "partition_id {} exceeds {} partitions",
+                    partition_id,
+                    self.index_cells.len()
+                ))
+            }
+        };
+        if !cell.is_empty() {
+            return get_index_from_cell(cell).await;
+        }
+
+        let res_index = {
+            let opt_res_index = self.maybe_build_index(partition_id).await;
+            match opt_res_index {
+                Some(res_index) => res_index,
+                None => {
+                    // The build side data for building the index has already 
been consumed by someone else,
+                    // we just need to wait for the task consumed the data to 
finish building the index.
+                    return get_index_from_cell(cell).await;
+                }
+            }
+        };
+
+        match res_index {
+            Ok(idx) => {
+                if let Err(e) = cell.set(Ok(Arc::clone(&idx))) {
+                    // This is probably because the cell has been disposed. No 
one
+                    // will get the index from the cell so this failure is not 
a big deal.
+                    log::debug!("Cannot set the index into the async cell: 
{:?}", e);
+                }
+                Some(Ok(idx))
+            }
+            Err(err) => {
+                let err_arc = Arc::new(err);
+                if let Err(e) = cell.set(Err(Arc::clone(&err_arc))) {
+                    log::debug!(
+                        "Cannot set the index build error into the async cell: 
{:?}",
+                        e
+                    );
+                }
+                Some(Err(DataFusionError::Shared(err_arc)))
+            }
+        }
+    }
+
+    async fn maybe_build_index(&self, partition_id: u32) -> 
Option<Result<Arc<SpatialIndex>>> {
+        match &self.data {
+            BuildSideData::SinglePartition(build_partition_opt) => {
+                if partition_id != 0 {
+                    return Some(sedona_internal_err!(
+                        "partition_id for single-partition index is not 0"
+                    ));
+                }
+
+                // consume the build side data for building the index
+                let build_partition_opt = {
+                    let mut locked = build_partition_opt.lock();
+                    std::mem::take(locked.deref_mut())
+                };
+
+                let Some(build_partition) = build_partition_opt else {
+                    // already consumed by previous attempts, the result 
should be present in the async cell.
+                    return None;
+                };
+                
Some(self.build_index_for_single_partition(build_partition).await)
+            }
+            BuildSideData::MultiPartition(partitioned_spill_files) => {
+                // consume this partition of build side data for building index
+                let spilled_partition = {
+                    let mut locked = partitioned_spill_files.lock();
+                    let partition = SpatialPartition::Regular(partition_id);
+                    if !locked.can_take_spilled_partition(partition) {
+                        // already consumed by previous attempts, the result 
should be present in the async cell.
+                        return None;
+                    }
+                    match locked.take_spilled_partition(partition) {
+                        Ok(spilled_partition) => spilled_partition,
+                        Err(e) => return Some(Err(e)),
+                    }
+                };
+                Some(
+                    self.build_index_for_spilled_partition(spilled_partition)
+                        .await,
+                )
+            }
+        }
+    }
+
+    #[cfg(test)]
+    pub async fn wait_for_index(&self, partition_id: u32) -> 
Option<Result<Arc<SpatialIndex>>> {
+        let cell = match self.index_cells.get(partition_id as usize) {
+            Some(cell) => cell,
+            None => {
+                return Some(sedona_internal_err!(
+                    "partition_id {} exceeds {} partitions",
+                    partition_id,
+                    self.index_cells.len()
+                ))
+            }
+        };
+
+        get_index_from_cell(cell).await
+    }
+
+    pub fn dispose_index(&self, partition_id: u32) {
+        if let Some(cell) = self.index_cells.get(partition_id as usize) {
+            cell.dispose();
+        }
+    }
+
+    pub fn num_loaded_indexes(&self) -> usize {
+        self.index_cells
+            .iter()
+            .filter(|index_cell| index_cell.is_set())
+            .count()
+    }
+
+    async fn build_index_for_single_partition(
+        &self,
+        build_partitions: Vec<BuildPartition>,
+    ) -> Result<Arc<SpatialIndex>> {
+        let mut index_builder = SpatialIndexBuilder::new(
+            Arc::clone(&self.schema),
+            self.spatial_predicate.clone(),
+            self.options.clone(),
+            self.join_type,
+            self.probe_threads_count,
+            self.metrics.clone(),
+        )?;
+
+        for build_partition in build_partitions {
+            let stream = build_partition.build_side_batch_stream;
+            let geo_statistics = build_partition.geo_statistics;
+            index_builder.add_stream(stream, geo_statistics).await?;
+        }
+
+        let index = index_builder.finish()?;
+        Ok(Arc::new(index))
+    }
+
+    async fn build_index_for_spilled_partition(
+        &self,
+        spilled_partition: SpilledPartition,
+    ) -> Result<Arc<SpatialIndex>> {
+        let mut index_builder = SpatialIndexBuilder::new(
+            Arc::clone(&self.schema),
+            self.spatial_predicate.clone(),
+            self.options.clone(),
+            self.join_type,
+            self.probe_threads_count,
+            self.metrics.clone(),
+        )?;
+
+        // Spawn tasks to load indexed batches from spilled files concurrently
+        let (spill_files, geo_statistics, _) = spilled_partition.into_inner();
+        let mut join_set: JoinSet<Result<(), DataFusionError>> = 
JoinSet::new();
+        let (tx, mut rx) = mpsc::channel(spill_files.len() * 2 + 1);
+        for spill_file in spill_files {
+            let tx = tx.clone();
+            join_set.spawn(async move {
+                let result = async {
+                    let mut stream = 
ExternalEvaluatedBatchStream::try_from_spill_file(spill_file)?;
+                    while let Some(batch) = stream.next().await {
+                        let indexed_batch = batch?;
+                        if tx.send(Ok(indexed_batch)).await.is_err() {
+                            return Ok(());
+                        }
+                    }
+                    Ok::<(), DataFusionError>(())
+                }
+                .await;
+                if let Err(e) = result {
+                    let _ = tx.send(Err(e)).await;
+                }
+                Ok(())
+            });
+        }
+        drop(tx);
+
+        // Collect the loaded indexed batches and add them to the index builder
+        while let Some(res) = rx.recv().await {
+            let batch = res?;
+            index_builder.add_batch(batch)?;
+        }
+
+        // Ensure all tasks completed successfully
+        while let Some(res) = join_set.join_next().await {
+            if let Err(e) = res {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                }
+                return Err(DataFusionError::External(Box::new(e)));
+            }
+        }
+
+        index_builder.merge_stats(geo_statistics);
+
+        let index = index_builder.finish()?;
+        Ok(Arc::new(index))
+    }
+}
+
+async fn get_index_from_cell(
+    cell: &DisposableAsyncCell<SharedResult<Arc<SpatialIndex>>>,
+) -> Option<Result<Arc<SpatialIndex>>> {
+    match cell.get().await {
+        Some(Ok(index)) => Some(Ok(index)),
+        Some(Err(shared_err)) => 
Some(Err(DataFusionError::Shared(shared_err))),
+        None => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::operand_evaluator::EvaluatedGeometryArray;
+    use crate::partitioning::partition_slots::PartitionSlots;
+    use crate::utils::bbox_sampler::BoundingBoxSamples;
+    use crate::{
+        evaluated_batch::{
+            evaluated_batch_stream::{
+                in_mem::InMemoryEvaluatedBatchStream, 
SendableEvaluatedBatchStream,
+            },
+            EvaluatedBatch,
+        },
+        index::CollectBuildSideMetrics,
+    };
+    use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch};
+    use arrow_schema::{DataType, Field, Schema, SchemaRef};
+    use datafusion::config::SpillCompression;
+    use datafusion_common::{DataFusionError, Result};
+    use datafusion_execution::{
+        memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool},
+        runtime_env::RuntimeEnv,
+    };
+    use datafusion_expr::JoinType;
+    use datafusion_physical_expr::expressions::Column;
+    use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, 
SpillMetrics};
+    use sedona_expr::statistics::GeoStatistics;
+    use sedona_functions::st_analyze_agg::AnalyzeAccumulator;
+    use sedona_geometry::analyze::analyze_geometry;
+    use sedona_schema::datatypes::WKB_GEOMETRY;
+
+    use crate::evaluated_batch::spill::EvaluatedBatchSpillWriter;
+    use crate::partitioning::stream_repartitioner::{SpilledPartition, 
SpilledPartitions};
+    use crate::spatial_predicate::{RelationPredicate, SpatialRelationType};
+
+    fn sample_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("geom", DataType::Binary, true),
+            Field::new("id", DataType::Int32, false),
+        ]))
+    }
+
+    fn point_wkb(x: f64, y: f64) -> Vec<u8> {
+        let mut buf = vec![1u8, 1, 0, 0, 0];
+        buf.extend_from_slice(&x.to_le_bytes());
+        buf.extend_from_slice(&y.to_le_bytes());
+        buf
+    }
+
+    fn sample_batch(ids: &[i32], wkbs: Vec<Option<Vec<u8>>>) -> 
Result<EvaluatedBatch> {
+        assert_eq!(ids.len(), wkbs.len());
+        let geom_values: Vec<Option<&[u8]>> = wkbs
+            .iter()
+            .map(|opt| opt.as_ref().map(|wkb| wkb.as_slice()))
+            .collect();
+        let geom_array: ArrayRef = 
Arc::new(BinaryArray::from_opt_vec(geom_values));
+        let id_array: ArrayRef = Arc::new(Int32Array::from(ids.to_vec()));
+        let batch = RecordBatch::try_new(sample_schema(), 
vec![geom_array.clone(), id_array])?;
+        let geom = EvaluatedGeometryArray::try_new(geom_array, &WKB_GEOMETRY)?;
+        Ok(EvaluatedBatch {
+            batch,
+            geom_array: geom,
+        })
+    }
+
+    fn predicate() -> SpatialPredicate {
+        SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("geom", 0)),
+            Arc::new(Column::new("geom", 0)),
+            SpatialRelationType::Intersects,
+        ))
+    }
+
+    fn geo_stats_from_batches(batches: &[EvaluatedBatch]) -> 
Result<GeoStatistics> {
+        let mut analyzer = AnalyzeAccumulator::new(WKB_GEOMETRY, WKB_GEOMETRY);
+        for batch in batches {
+            for wkb in batch.geom_array.wkbs().iter().flatten() {
+                let summary =
+                    analyze_geometry(wkb).map_err(|e| 
DataFusionError::External(Box::new(e)))?;
+                analyzer.ingest_geometry_summary(&summary);
+            }
+        }
+        Ok(analyzer.finish())
+    }
+
+    fn new_reservation(memory_pool: Arc<dyn MemoryPool>) -> MemoryReservation {
+        let consumer = MemoryConsumer::new("PartitionedIndexProviderTest");
+        consumer.register(&memory_pool)
+    }
+
+    fn build_partition_from_batches(
+        memory_pool: Arc<dyn MemoryPool>,
+        batches: Vec<EvaluatedBatch>,
+    ) -> Result<BuildPartition> {
+        let schema = batches
+            .first()
+            .map(|batch| batch.schema())
+            .unwrap_or_else(|| Arc::new(Schema::empty()));
+        let geo_statistics = geo_stats_from_batches(&batches)?;
+        let num_rows = batches.iter().map(|batch| batch.num_rows()).sum();
+        let mut estimated_usage = 0;
+        for batch in &batches {
+            estimated_usage += batch.in_mem_size()?;
+        }
+        let stream: SendableEvaluatedBatchStream =
+            Box::pin(InMemoryEvaluatedBatchStream::new(schema, batches));
+        Ok(BuildPartition {
+            num_rows,
+            build_side_batch_stream: stream,
+            geo_statistics,
+            bbox_samples: BoundingBoxSamples::empty(),
+            estimated_spatial_index_memory_usage: estimated_usage,
+            reservation: new_reservation(memory_pool),
+            metrics: CollectBuildSideMetrics::new(0, 
&ExecutionPlanMetricsSet::new()),
+        })
+    }
+
+    fn spill_partition_from_batches(
+        runtime_env: Arc<RuntimeEnv>,
+        batches: Vec<EvaluatedBatch>,
+    ) -> Result<SpilledPartition> {
+        if batches.is_empty() {
+            return Ok(SpilledPartition::empty());
+        }
+        let schema = batches[0].schema();
+        let sedona_type = batches[0].geom_array.sedona_type.clone();
+        let mut writer = EvaluatedBatchSpillWriter::try_new(
+            runtime_env,
+            schema,
+            &sedona_type,
+            "partitioned-index-provider-test",
+            SpillCompression::Uncompressed,
+            SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+            None,
+        )?;
+        let mut num_rows = 0;
+        for batch in &batches {
+            num_rows += batch.num_rows();
+            writer.append(batch)?;
+        }
+        let geo_statistics = geo_stats_from_batches(&batches)?;
+        let spill_file = writer.finish()?;
+        Ok(SpilledPartition::new(
+            vec![Arc::new(spill_file)],
+            geo_statistics,
+            num_rows,
+        ))
+    }
+
+    fn make_spilled_partitions(
+        runtime_env: Arc<RuntimeEnv>,
+        partitions: Vec<Vec<EvaluatedBatch>>,
+    ) -> Result<SpilledPartitions> {
+        let slots = PartitionSlots::new(partitions.len());
+        let mut spilled = Vec::with_capacity(slots.total_slots());
+        for partition_batches in partitions {
+            spilled.push(spill_partition_from_batches(
+                Arc::clone(&runtime_env),
+                partition_batches,
+            )?);
+        }
+        spilled.push(SpilledPartition::empty());
+        spilled.push(SpilledPartition::empty());
+        Ok(SpilledPartitions::new(slots, spilled))
+    }
+
+    #[tokio::test]
+    async fn single_partition_builds_once_and_is_cached() -> Result<()> {
+        let memory_pool: Arc<dyn MemoryPool> = 
Arc::new(GreedyMemoryPool::new(1 << 20));
+        let batches = vec![sample_batch(
+            &[1, 2],
+            vec![Some(point_wkb(10.0, 10.0)), Some(point_wkb(20.0, 20.0))],
+        )?];
+        let build_partition = 
build_partition_from_batches(Arc::clone(&memory_pool), batches)?;
+        let metrics = ExecutionPlanMetricsSet::new();
+        let provider = PartitionedIndexProvider::new_single_partition(
+            sample_schema(),
+            predicate(),
+            SpatialJoinOptions::default(),
+            JoinType::Inner,
+            1,
+            vec![build_partition],
+            SpatialJoinBuildMetrics::new(0, &metrics),
+        );
+
+        let first_index = provider
+            .build_or_wait_for_index(0)
+            .await
+            .expect("partition exists")?;
+        assert_eq!(first_index.indexed_batches.len(), 1);
+        assert_eq!(provider.num_loaded_indexes(), 1);
+
+        let cached_index = provider
+            .wait_for_index(0)
+            .await
+            .expect("cached value must remain accessible")?;
+        assert!(Arc::ptr_eq(&first_index, &cached_index));
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn multi_partition_concurrent_requests_share_indexes() -> Result<()> 
{
+        let memory_pool: Arc<dyn MemoryPool> = 
Arc::new(GreedyMemoryPool::new(1 << 20));
+        let runtime_env = Arc::new(RuntimeEnv::default());
+        let partition_batches = vec![
+            vec![sample_batch(&[10], vec![Some(point_wkb(0.0, 0.0))])?],
+            vec![sample_batch(&[20], vec![Some(point_wkb(50.0, 50.0))])?],
+        ];
+        let spilled_partitions = make_spilled_partitions(runtime_env, 
partition_batches)?;
+        let metrics = ExecutionPlanMetricsSet::new();
+        let provider = Arc::new(PartitionedIndexProvider::new_multi_partition(
+            sample_schema(),
+            predicate(),
+            SpatialJoinOptions::default(),
+            JoinType::Inner,
+            1,
+            spilled_partitions,
+            SpatialJoinBuildMetrics::new(0, &metrics),
+            vec![new_reservation(Arc::clone(&memory_pool))],
+        ));
+
+        let (idx_one, idx_two) = tokio::join!(
+            provider.build_or_wait_for_index(0),
+            provider.build_or_wait_for_index(0)
+        );
+        let idx_one = idx_one.expect("partition exists")?;
+        let idx_two = idx_two.expect("partition exists")?;
+        assert!(Arc::ptr_eq(&idx_one, &idx_two));
+        assert_eq!(idx_one.indexed_batches.len(), 1);
+
+        let second_partition = provider
+            .build_or_wait_for_index(1)
+            .await
+            .expect("second partition exists")?;
+        assert_eq!(second_partition.indexed_batches.len(), 1);
+        assert_eq!(provider.num_loaded_indexes(), 2);
+        Ok(())
+    }
+}
diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs 
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 9364920a..bff7895d 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -27,7 +27,6 @@ use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_common_runtime::JoinSet;
-use datafusion_execution::memory_pool::MemoryReservation;
 use float_next_after::NextAfter;
 use geo::BoundingRect;
 use geo_index::rtree::{
@@ -95,11 +94,6 @@ pub struct SpatialIndex {
 
     /// Shared KNN components (distance metrics and geometry cache) for 
efficient KNN queries
     pub(crate) knn_components: Option<KnnComponents>,
-
-    /// Memory reservation for tracking the memory usage of the spatial index
-    /// Cleared on `SpatialIndex` drop
-    #[expect(dead_code)]
-    pub(crate) reservation: MemoryReservation,
 }
 
 impl SpatialIndex {
@@ -108,7 +102,6 @@ impl SpatialIndex {
         schema: SchemaRef,
         options: SpatialJoinOptions,
         probe_threads_counter: AtomicUsize,
-        reservation: MemoryReservation,
     ) -> Self {
         let evaluator = create_operand_evaluator(&spatial_predicate, 
options.clone());
         let refiner = create_refiner(
@@ -133,7 +126,6 @@ impl SpatialIndex {
             visited_build_side: None,
             probe_threads_counter,
             knn_components,
-            reservation,
         }
     }
 
@@ -681,7 +673,6 @@ mod tests {
     use arrow_array::RecordBatch;
     use arrow_schema::{DataType, Field};
     use datafusion_common::JoinSide;
-    use datafusion_execution::memory_pool::GreedyMemoryPool;
     use datafusion_expr::JoinType;
     use datafusion_physical_expr::expressions::Column;
     use geo_traits::Dimensions;
@@ -692,7 +683,6 @@ mod tests {
 
     #[test]
     fn test_spatial_index_builder_empty() {
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -711,7 +701,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -724,7 +713,6 @@ mod tests {
 
     #[test]
     fn test_spatial_index_builder_add_batch() {
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -750,7 +738,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -779,7 +766,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_with_sample_data() {
         // Create a spatial index with sample geometry data
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -807,7 +793,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -878,7 +863,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_with_different_k_values() {
         // Create spatial index with more data points
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -905,7 +889,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -969,7 +952,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_with_spheroid_distance() {
         // Create spatial index
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -996,7 +978,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1066,7 +1047,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_edge_cases() {
         // Create spatial index
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1093,7 +1073,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1159,7 +1138,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_empty_index() {
         // Create empty spatial index
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1181,7 +1159,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1207,7 +1184,6 @@ mod tests {
     #[test]
     fn test_knn_query_execution_with_tie_breakers() {
         // Create a spatial index with sample geometry data
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1234,7 +1210,6 @@ mod tests {
             options,
             JoinType::Inner,
             1, // probe_threads_count
-            memory_pool.clone(),
             metrics,
         )
         .unwrap();
@@ -1322,7 +1297,6 @@ mod tests {
     #[test]
     fn test_query_knn_with_geometry_distance() {
         // Create a spatial index with sample geometry data
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1350,7 +1324,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1407,7 +1380,6 @@ mod tests {
     fn test_query_knn_with_mixed_geometries() {
         // Create a spatial index with complex geometries where geometry-based
         // distance should differ from centroid-based distance
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1435,7 +1407,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1489,7 +1460,6 @@ mod tests {
     #[test]
     fn test_query_knn_with_tie_breakers_geometry_distance() {
         // Create a spatial index with geometries that have identical 
distances for tie-breaker testing
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1516,7 +1486,6 @@ mod tests {
             options,
             JoinType::Inner,
             4,
-            memory_pool,
             metrics,
         )
         .unwrap();
@@ -1610,7 +1579,6 @@ mod tests {
     #[test]
     fn test_knn_query_with_empty_geometry() {
         // Create a spatial index with sample geometry data like other tests
-        let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
         let options = SpatialJoinOptions {
             execution_mode: ExecutionMode::PrepareBuild,
             ..Default::default()
@@ -1638,7 +1606,6 @@ mod tests {
             options,
             JoinType::Inner,
             1, // probe_threads_count
-            memory_pool.clone(),
             metrics,
         )
         .unwrap();
@@ -1687,7 +1654,6 @@ mod tests {
         build_geoms: &[Option<&str>],
         options: SpatialJoinOptions,
     ) -> Arc<SpatialIndex> {
-        let memory_pool = Arc::new(GreedyMemoryPool::new(100 * 1024 * 1024));
         let metrics = SpatialJoinBuildMetrics::default();
         let spatial_predicate = 
SpatialPredicate::Relation(RelationPredicate::new(
             Arc::new(Column::new("left", 0)),
@@ -1706,7 +1672,6 @@ mod tests {
             options,
             JoinType::Inner,
             1,
-            memory_pool,
             metrics,
         )
         .unwrap();
diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs 
b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
index 9d97b539..ca2b0088 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
@@ -22,16 +22,15 @@ use sedona_common::SpatialJoinOptions;
 use sedona_expr::statistics::GeoStatistics;
 
 use datafusion_common::{utils::proxy::VecAllocExt, Result};
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, 
MemoryReservation};
 use datafusion_expr::JoinType;
 use futures::StreamExt;
 use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
 use parking_lot::Mutex;
-use std::sync::{atomic::AtomicUsize, Arc};
+use std::sync::atomic::AtomicUsize;
 
 use crate::{
-    evaluated_batch::EvaluatedBatch,
-    index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex, 
BuildPartition},
+    evaluated_batch::{evaluated_batch_stream::SendableEvaluatedBatchStream, 
EvaluatedBatch},
+    index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex},
     operand_evaluator::create_operand_evaluator,
     refine::create_refiner,
     spatial_predicate::SpatialPredicate,
@@ -63,8 +62,6 @@ pub struct SpatialIndexBuilder {
 
     /// Batches to be indexed
     indexed_batches: Vec<EvaluatedBatch>,
-    /// Memory reservation for tracking the memory usage of the spatial index
-    reservation: MemoryReservation,
 
     /// Statistics for indexed geometries
     stats: GeoStatistics,
@@ -99,12 +96,8 @@ impl SpatialIndexBuilder {
         options: SpatialJoinOptions,
         join_type: JoinType,
         probe_threads_count: usize,
-        memory_pool: Arc<dyn MemoryPool>,
         metrics: SpatialJoinBuildMetrics,
     ) -> Result<Self> {
-        let consumer = MemoryConsumer::new("SpatialJoinIndex");
-        let reservation = consumer.register(&memory_pool);
-
         Ok(Self {
             schema,
             spatial_predicate,
@@ -113,7 +106,6 @@ impl SpatialIndexBuilder {
             probe_threads_count,
             metrics,
             indexed_batches: Vec::new(),
-            reservation,
             stats: GeoStatistics::empty(),
             memory_used: 0,
         })
@@ -258,7 +250,6 @@ impl SpatialIndexBuilder {
                 self.schema,
                 self.options,
                 AtomicUsize::new(self.probe_threads_count),
-                self.reservation,
             ));
         }
 
@@ -297,6 +288,10 @@ impl SpatialIndexBuilder {
             }
         };
 
+        log::debug!(
+            "Estimated memory used by spatial index: {}",
+            self.memory_used
+        );
         Ok(SpatialIndex {
             schema: self.schema,
             options: self.options,
@@ -309,26 +304,19 @@ impl SpatialIndexBuilder {
             visited_build_side,
             probe_threads_counter: AtomicUsize::new(self.probe_threads_count),
             knn_components: knn_components_opt,
-            reservation: self.reservation,
         })
     }
 
-    pub async fn add_partitions(&mut self, partitions: Vec<BuildPartition>) -> 
Result<()> {
-        for partition in partitions {
-            self.add_partition(partition).await?;
-        }
-        Ok(())
-    }
-
-    pub async fn add_partition(&mut self, mut partition: BuildPartition) -> 
Result<()> {
-        let mut stream = partition.build_side_batch_stream;
+    pub async fn add_stream(
+        &mut self,
+        mut stream: SendableEvaluatedBatchStream,
+        geo_statistics: GeoStatistics,
+    ) -> Result<()> {
         while let Some(batch) = stream.next().await {
             let indexed_batch = batch?;
             self.add_batch(indexed_batch)?;
         }
-        self.merge_stats(partition.geo_statistics);
-        let mem_bytes = partition.reservation.free();
-        self.reservation.try_grow(mem_bytes)?;
+        self.merge_stats(geo_statistics);
         Ok(())
     }
 
diff --git a/rust/sedona-spatial-join/src/lib.rs 
b/rust/sedona-spatial-join/src/lib.rs
index 94af3f22..2abaf3c4 100644
--- a/rust/sedona-spatial-join/src/lib.rs
+++ b/rust/sedona-spatial-join/src/lib.rs
@@ -15,13 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod build_index;
 pub mod evaluated_batch;
 pub mod exec;
 mod index;
 pub mod operand_evaluator;
 pub mod optimizer;
 pub mod partitioning;
+mod prepare;
 pub mod refine;
 pub mod spatial_predicate;
 mod stream;
@@ -31,7 +31,6 @@ pub use exec::SpatialJoinExec;
 pub use optimizer::register_spatial_join_optimizer;
 
 // Re-export types needed for external usage (e.g., in Comet)
-pub use build_index::build_index;
 pub use index::{SpatialIndex, SpatialJoinBuildMetrics};
 pub use spatial_predicate::SpatialPredicate;
 
diff --git a/rust/sedona-spatial-join/src/partitioning/kdb.rs 
b/rust/sedona-spatial-join/src/partitioning/kdb.rs
index 32ac3a4c..c09e98ff 100644
--- a/rust/sedona-spatial-join/src/partitioning/kdb.rs
+++ b/rust/sedona-spatial-join/src/partitioning/kdb.rs
@@ -43,7 +43,9 @@
 use std::sync::Arc;
 
 use crate::partitioning::{
-    util::{bbox_to_geo_rect, rect_contains_point, rect_intersection_area, 
rects_intersect},
+    util::{
+        bbox_to_geo_rect, make_rect, rect_contains_point, 
rect_intersection_area, rects_intersect,
+    },
     SpatialPartition, SpatialPartitioner,
 };
 use datafusion_common::Result;
@@ -126,9 +128,12 @@ impl KDBTree {
         if max_items_per_node == 0 {
             return sedona_internal_err!("max_items_per_node must be greater 
than 0");
         }
-        let Some(extent_rect) = bbox_to_geo_rect(&extent)? else {
-            return sedona_internal_err!("KDBTree extent cannot be empty");
-        };
+
+        // extent_rect is a sentinel rect if the bounding box is empty. In 
that case,
+        // almost all insertions will be ignored. We are free to partition the 
data
+        // arbitrarily when the extent is empty.
+        let extent_rect = bbox_to_geo_rect(&extent)?.unwrap_or(make_rect(0.0, 
0.0, 0.0, 0.0));
+
         Ok(Self::new_with_level(
             max_items_per_node,
             max_levels,
@@ -507,6 +512,13 @@ impl KDBPartitioner {
         }
         Ok(())
     }
+
+    /// Return the tree structure in human-readable format for debugging 
purposes.
+    pub fn debug_str(&self) -> String {
+        let mut output = String::new();
+        let _ = self.debug_print(&mut output);
+        output
+    }
 }
 
 impl SpatialPartitioner for KDBPartitioner {
@@ -966,4 +978,19 @@ mod tests {
             SpatialPartition::None
         );
     }
+
+    #[test]
+    fn test_kdb_partitioner_empty_extent() {
+        let extent = BoundingBox::empty();
+        let bboxes = vec![
+            BoundingBox::xy((0.0, 10.0), (0.0, 10.0)),
+            BoundingBox::xy((1.0, 10.0), (1.0, 10.0)),
+        ];
+        let partitioner = KDBPartitioner::build(bboxes.clone().into_iter(), 
10, 4, extent).unwrap();
+
+        // Partition calls should succeed
+        for test_bbox in bboxes {
+            assert!(partitioner.partition(&test_bbox).is_ok());
+        }
+    }
 }
diff --git a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs 
b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
index 44591107..038530b1 100644
--- a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
@@ -280,6 +280,13 @@ impl SpilledPartitions {
         }
         Ok(())
     }
+
+    /// Return debug info for these spilled partitions as a string.
+    pub fn debug_str(&self) -> String {
+        let mut output = String::new();
+        let _ = self.debug_print(&mut output);
+        output
+    }
 }
 
 /// Incremental (stateful) repartitioner for an [`EvaluatedBatch`] stream.
diff --git a/rust/sedona-spatial-join/src/prepare.rs 
b/rust/sedona-spatial-join/src/prepare.rs
new file mode 100644
index 00000000..76e825b3
--- /dev/null
+++ b/rust/sedona-spatial-join/src/prepare.rs
@@ -0,0 +1,514 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{mem, sync::Arc};
+
+use arrow_schema::SchemaRef;
+use datafusion_common::Result;
+use datafusion_common_runtime::JoinSet;
+use datafusion_execution::{
+    disk_manager::RefCountedTempFile, memory_pool::MemoryConsumer, 
SendableRecordBatchStream,
+    TaskContext,
+};
+use datafusion_expr::JoinType;
+use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;
+use fastrand::Rng;
+use sedona_common::{sedona_internal_err, NumSpatialPartitionsConfig, 
SedonaOptions};
+use sedona_expr::statistics::GeoStatistics;
+use sedona_geometry::bounding_box::BoundingBox;
+
+use crate::{
+    index::{
+        memory_plan::{compute_memory_plan, MemoryPlan, PartitionMemorySummary},
+        partitioned_index_provider::PartitionedIndexProvider,
+        BuildPartition, BuildSideBatchesCollector, CollectBuildSideMetrics,
+        SpatialJoinBuildMetrics,
+    },
+    partitioning::{
+        kdb::KDBPartitioner,
+        stream_repartitioner::{SpilledPartition, SpilledPartitions, 
StreamRepartitioner},
+        PartitionedSide, SpatialPartition, SpatialPartitioner,
+    },
+    spatial_predicate::SpatialPredicate,
+    utils::bbox_sampler::BoundingBoxSamples,
+};
+
+pub(crate) struct SpatialJoinComponents {
+    pub partitioned_index_provider: Arc<PartitionedIndexProvider>,
+}
+
+/// Builder for constructing `SpatialJoinComponents` from build-side streams.
+///
+/// Calling `build(...)` performs the full preparation flow:
+/// - collect (and spill if needed) build-side batches,
+/// - compute memory plan and pick single- or multi-partition mode,
+/// - repartition the build side into spatial partitions in multi-partition 
mode,
+/// - create the appropriate `PartitionedIndexProvider` for creating spatial 
indexes.
+pub(crate) struct SpatialJoinComponentsBuilder {
+    context: Arc<TaskContext>,
+    build_schema: SchemaRef,
+    spatial_predicate: SpatialPredicate,
+    join_type: JoinType,
+    probe_threads_count: usize,
+    metrics: ExecutionPlanMetricsSet,
+    seed: u64,
+    sedona_options: SedonaOptions,
+}
+
+impl SpatialJoinComponentsBuilder {
+    /// Create a new builder capturing the execution context and configuration
+    /// required to produce `SpatialJoinComponents` from build-side streams.
+    pub fn new(
+        context: Arc<TaskContext>,
+        build_schema: SchemaRef,
+        spatial_predicate: SpatialPredicate,
+        join_type: JoinType,
+        probe_threads_count: usize,
+        metrics: ExecutionPlanMetricsSet,
+        seed: u64,
+    ) -> Self {
+        let session_config = context.session_config();
+        let sedona_options = session_config
+            .options()
+            .extensions
+            .get::<SedonaOptions>()
+            .cloned()
+            .unwrap_or_default();
+        Self {
+            context,
+            build_schema,
+            spatial_predicate,
+            join_type,
+            probe_threads_count,
+            metrics,
+            seed,
+            sedona_options,
+        }
+    }
+
+    /// Prepare and return `SpatialJoinComponents` for the given build-side
+    /// streams. This drives the end-to-end preparation flow and returns a
+    /// ready-to-use `SpatialJoinComponents` for the spatial join operator.
+    pub async fn build(
+        mut self,
+        build_streams: Vec<SendableRecordBatchStream>,
+    ) -> Result<SpatialJoinComponents> {
+        let num_partitions = build_streams.len();
+        if num_partitions == 0 {
+            log::debug!("Build side has no data. Creating empty spatial 
index.");
+            let partitioned_index_provider = 
PartitionedIndexProvider::new_empty(
+                self.build_schema,
+                self.spatial_predicate,
+                self.sedona_options.spatial_join,
+                self.join_type,
+                self.probe_threads_count,
+                SpatialJoinBuildMetrics::new(0, &self.metrics),
+            );
+            return Ok(SpatialJoinComponents {
+                partitioned_index_provider: 
Arc::new(partitioned_index_provider),
+            });
+        }
+
+        let mut rng = Rng::with_seed(self.seed);
+        let mut build_partitions = self
+            .collect_build_partitions(build_streams, rng.u64(0..0xFFFF))
+            .await?;
+
+        // Determine the number of spatial partitions based on the memory 
reserved and the estimated amount of
+        // memory required for loading the entire build side into a spatial 
index
+        let memory_plan =
+            
compute_memory_plan(build_partitions.iter().map(PartitionMemorySummary::from))?;
+        log::debug!("Computed memory plan for spatial join:\n{:#?}", 
memory_plan);
+        let num_partitions = match self
+            .sedona_options
+            .spatial_join
+            .debug
+            .num_spatial_partitions
+        {
+            NumSpatialPartitionsConfig::Auto => memory_plan.num_partitions,
+            NumSpatialPartitionsConfig::Fixed(n) => {
+                log::debug!("Override number of spatial partitions to {}", n);
+                n
+            }
+        };
+
+        if num_partitions == 1 {
+            log::debug!("Running single-partitioned in-memory spatial join");
+            let partitioned_index_provider = 
PartitionedIndexProvider::new_single_partition(
+                self.build_schema,
+                self.spatial_predicate,
+                self.sedona_options.spatial_join,
+                self.join_type,
+                self.probe_threads_count,
+                build_partitions,
+                SpatialJoinBuildMetrics::new(0, &self.metrics),
+            );
+            Ok(SpatialJoinComponents {
+                partitioned_index_provider: 
Arc::new(partitioned_index_provider),
+            })
+        } else {
+            // Collect all memory reservations grown during build side 
collection
+            let mut reservations = Vec::with_capacity(build_partitions.len());
+            for partition in &mut build_partitions {
+                reservations.push(partition.reservation.take());
+            }
+
+            // Partition the build side into multiple spatial partitions, each 
partition can be fully
+            // loaded into an in-memory spatial index
+            let build_partitioner = self.build_spatial_partitioner(
+                num_partitions,
+                &mut build_partitions,
+                rng.u64(0..0xFFFF),
+            )?;
+            let partitioned_spill_files_vec = self
+                .repartition_build_side(build_partitions, build_partitioner, 
&memory_plan)
+                .await?;
+
+            let merged_spilled_partitions = 
merge_spilled_partitions(partitioned_spill_files_vec)?;
+            log::debug!(
+                "Build side spatial partitions:\n{}",
+                merged_spilled_partitions.debug_str()
+            );
+
+            // Sanity check: Multi and None partitions must be empty. All the 
geometries in the build side
+            // should fall into regular partitions
+            for partition in [SpatialPartition::None, SpatialPartition::Multi] 
{
+                let spilled_partition = 
merged_spilled_partitions.spilled_partition(partition)?;
+                if !spilled_partition.spill_files().is_empty() {
+                    return sedona_internal_err!(
+                        "Build side spatial partitions {:?} should be empty",
+                        partition
+                    );
+                }
+            }
+
+            let partitioned_index_provider = 
PartitionedIndexProvider::new_multi_partition(
+                self.build_schema,
+                self.spatial_predicate,
+                self.sedona_options.spatial_join,
+                self.join_type,
+                self.probe_threads_count,
+                merged_spilled_partitions,
+                SpatialJoinBuildMetrics::new(0, &self.metrics),
+                reservations,
+            );
+
+            Ok(SpatialJoinComponents {
+                partitioned_index_provider: 
Arc::new(partitioned_index_provider),
+            })
+        }
+    }
+
+    /// Collect build-side batches from the provided streams and return a
+    /// vector of `BuildPartition` entries representing the collected data.
+    /// The collector may spill to disk according to the configured options.
+    async fn collect_build_partitions(
+        &mut self,
+        build_streams: Vec<SendableRecordBatchStream>,
+        seed: u64,
+    ) -> Result<Vec<BuildPartition>> {
+        let runtime_env = self.context.runtime_env();
+        let session_config = self.context.session_config();
+        let spill_compression = session_config.spill_compression();
+
+        let num_partitions = build_streams.len();
+        let mut collect_metrics_vec = Vec::with_capacity(num_partitions);
+        let mut reservations = Vec::with_capacity(num_partitions);
+        let memory_pool = self.context.memory_pool();
+        for k in 0..num_partitions {
+            let consumer = 
MemoryConsumer::new(format!("SpatialJoinCollectBuildSide[{k}]"))
+                .with_can_spill(true);
+            let reservation = consumer.register(memory_pool);
+            reservations.push(reservation);
+            collect_metrics_vec.push(CollectBuildSideMetrics::new(k, 
&self.metrics));
+        }
+
+        let collector = BuildSideBatchesCollector::new(
+            self.spatial_predicate.clone(),
+            self.sedona_options.spatial_join.clone(),
+            Arc::clone(&runtime_env),
+            spill_compression,
+        );
+        let build_partitions = collector
+            .collect_all(
+                build_streams,
+                reservations,
+                collect_metrics_vec.clone(),
+                self.sedona_options
+                    .spatial_join
+                    .concurrent_build_side_collection,
+                seed,
+            )
+            .await?;
+
+        Ok(build_partitions)
+    }
+
+    /// Construct a `SpatialPartitioner` (e.g. KDB) from collected samples so
+    /// the build and probe sides can be partitioned spatially across
+    /// `num_partitions`.
+    fn build_spatial_partitioner(
+        &self,
+        num_partitions: usize,
+        build_partitions: &mut Vec<BuildPartition>,
+        seed: u64,
+    ) -> Result<Arc<dyn SpatialPartitioner>> {
+        if matches!(
+            self.spatial_predicate,
+            SpatialPredicate::KNearestNeighbors(..)
+        ) {
+            return sedona_internal_err!("Partitioned KNN join is not supported 
yet");
+        }
+
+        let build_partitioner: Arc<dyn SpatialPartitioner> = {
+            // Use spatial partitioners to partition the build side and the 
probe side, this will
+            // reduce the amount of work needed for probing each partitioned 
index.
+            // The KDB partitioner is built using the collected bounding box 
samples.
+            let mut bbox_samples = BoundingBoxSamples::empty();
+            let mut geo_stats = GeoStatistics::empty();
+            let mut rng = Rng::with_seed(seed);
+            for partition in build_partitions {
+                let samples = mem::take(&mut partition.bbox_samples);
+                bbox_samples = bbox_samples.combine(samples, &mut rng);
+                geo_stats.merge(&partition.geo_statistics);
+            }
+
+            let extent = 
geo_stats.bbox().cloned().unwrap_or(BoundingBox::empty());
+            let mut samples = bbox_samples.take_samples();
+            let max_items_per_node = 1.max(samples.len() / num_partitions);
+            let max_levels = num_partitions;
+
+            log::debug!(
+                "Number of samples: {}, max_items_per_node: {}, max_levels: 
{}",
+                samples.len(),
+                max_items_per_node,
+                max_levels
+            );
+            rng.shuffle(&mut samples);
+            let kdb_partitioner =
+                KDBPartitioner::build(samples.into_iter(), max_items_per_node, 
max_levels, extent)?;
+            log::debug!(
+                "Built KDB spatial partitioner with {} partitions",
+                num_partitions
+            );
+            log::debug!(
+                "KDB partitioner debug info:\n{}",
+                kdb_partitioner.debug_str()
+            );
+
+            Arc::new(kdb_partitioner)
+        };
+
+        Ok(build_partitioner)
+    }
+
+    /// Repartition the collected build-side partitions using the provided
+    /// `SpatialPartitioner`. Returns the spilled partitions for each spatial 
partition.
+    async fn repartition_build_side(
+        &self,
+        build_partitions: Vec<BuildPartition>,
+        build_partitioner: Arc<dyn SpatialPartitioner>,
+        memory_plan: &MemoryPlan,
+    ) -> Result<Vec<SpilledPartitions>> {
+        // Spawn each task for each build partition to repartition the data 
using the spatial partitioner for
+        // the build/indexed side
+        let runtime_env = self.context.runtime_env();
+        let session_config = self.context.session_config();
+        let target_batch_size = session_config.batch_size();
+        let spill_compression = session_config.spill_compression();
+        let spilled_batch_in_memory_size_threshold = if self
+            .sedona_options
+            .spatial_join
+            .spilled_batch_in_memory_size_threshold
+            == 0
+        {
+            None
+        } else {
+            Some(
+                self.sedona_options
+                    .spatial_join
+                    .spilled_batch_in_memory_size_threshold,
+            )
+        };
+
+        let memory_for_intermittent_usage = match self
+            .sedona_options
+            .spatial_join
+            .debug
+            .memory_for_intermittent_usage
+        {
+            Some(value) => {
+                log::debug!("Override memory for intermittent usage to {}", 
value);
+                value
+            }
+            None => memory_plan.memory_for_intermittent_usage,
+        };
+
+        let mut join_set = JoinSet::new();
+        let buffer_bytes_threshold = memory_for_intermittent_usage / 
build_partitions.len();
+        for partition in build_partitions {
+            let stream = partition.build_side_batch_stream;
+            let metrics = &partition.metrics;
+            let spill_metrics = metrics.spill_metrics();
+            let runtime_env = Arc::clone(&runtime_env);
+            let partitioner = Arc::clone(&build_partitioner);
+            join_set.spawn(async move {
+                let partitioned_spill_files = StreamRepartitioner::builder(
+                    runtime_env,
+                    partitioner,
+                    PartitionedSide::BuildSide,
+                    spill_metrics,
+                )
+                .spill_compression(spill_compression)
+                .buffer_bytes_threshold(buffer_bytes_threshold)
+                .target_batch_size(target_batch_size)
+                
.spilled_batch_in_memory_size_threshold(spilled_batch_in_memory_size_threshold)
+                .build()
+                .repartition_stream(stream)
+                .await;
+                partitioned_spill_files
+            });
+        }
+
+        let results = join_set.join_all().await;
+        let partitioned_spill_files_vec = 
results.into_iter().collect::<Result<Vec<_>>>()?;
+        Ok(partitioned_spill_files_vec)
+    }
+}
+
+/// Aggregate the spill files and bounds of each spatial partition collected 
from all build partitions
+fn merge_spilled_partitions(
+    spilled_partitions_vec: Vec<SpilledPartitions>,
+) -> Result<SpilledPartitions> {
+    let Some(first) = spilled_partitions_vec.first() else {
+        return sedona_internal_err!("spilled_partitions_vec cannot be empty");
+    };
+
+    let slots = first.slots();
+    let total_slots = slots.total_slots();
+    let mut merged_spill_files: Vec<Vec<Arc<RefCountedTempFile>>> =
+        (0..total_slots).map(|_| Vec::new()).collect();
+    let mut partition_geo_stats: Vec<GeoStatistics> =
+        (0..total_slots).map(|_| GeoStatistics::empty()).collect();
+    let mut partition_num_rows: Vec<usize> = (0..total_slots).map(|_| 
0).collect();
+
+    for spilled_partitions in spilled_partitions_vec {
+        let partitions = spilled_partitions.into_spilled_partitions()?;
+        for (slot_idx, partition) in partitions.into_iter().enumerate() {
+            let (spill_files, geo_stats, num_rows) = partition.into_inner();
+            partition_geo_stats[slot_idx].merge(&geo_stats);
+            merged_spill_files[slot_idx].extend(spill_files);
+            partition_num_rows[slot_idx] += num_rows;
+        }
+    }
+
+    let merged_partitions = merged_spill_files
+        .into_iter()
+        .zip(partition_geo_stats)
+        .zip(partition_num_rows)
+        .map(|((spill_files, geo_stats), num_rows)| {
+            SpilledPartition::new(spill_files, geo_stats, num_rows)
+        })
+        .collect();
+
+    Ok(SpilledPartitions::new(slots, merged_partitions))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::partitioning::partition_slots::PartitionSlots;
+    use datafusion_execution::runtime_env::RuntimeEnv;
+    use sedona_geometry::interval::IntervalTrait;
+
+    fn sample_geo_stats(bbox: (f64, f64, f64, f64), total_geometries: i64) -> 
GeoStatistics {
+        GeoStatistics::empty()
+            .with_bbox(Some(BoundingBox::xy((bbox.0, bbox.1), (bbox.2, 
bbox.3))))
+            .with_total_geometries(total_geometries)
+    }
+
+    fn sample_partition(
+        env: &Arc<RuntimeEnv>,
+        labels: &[&str],
+        bbox: (f64, f64, f64, f64),
+        total_geometries: i64,
+    ) -> Result<SpilledPartition> {
+        let mut files = Vec::with_capacity(labels.len());
+        for label in labels {
+            files.push(Arc::new(env.disk_manager.create_tmp_file(label)?));
+        }
+        Ok(SpilledPartition::new(
+            files,
+            sample_geo_stats(bbox, total_geometries),
+            total_geometries as usize,
+        ))
+    }
+
+    #[test]
+    fn merge_spilled_partitions_combines_files_and_stats() -> Result<()> {
+        let runtime_env = Arc::new(RuntimeEnv::default());
+        let slots = PartitionSlots::new(2);
+
+        let partitions_a = vec![
+            sample_partition(&runtime_env, &["r0_a"], (0.0, 1.0, 0.0, 1.0), 
10)?,
+            sample_partition(&runtime_env, &["r1_a"], (10.0, 11.0, -1.0, 1.0), 
5)?,
+            sample_partition(&runtime_env, &["none_a"], (-5.0, -4.0, -5.0, 
-4.0), 2)?,
+            SpilledPartition::empty(),
+        ];
+        let first = SpilledPartitions::new(slots, partitions_a);
+
+        let partitions_b = vec![
+            sample_partition(&runtime_env, &["r0_b1", "r0_b2"], (5.0, 6.0, 
5.0, 6.0), 20)?,
+            sample_partition(&runtime_env, &[], (12.0, 13.0, 2.0, 3.0), 8)?,
+            SpilledPartition::empty(),
+            sample_partition(&runtime_env, &["multi_b"], (50.0, 51.0, 50.0, 
51.0), 1)?,
+        ];
+        let second = SpilledPartitions::new(slots, partitions_b);
+
+        let merged = merge_spilled_partitions(vec![first, second])?;
+
+        assert_eq!(merged.spill_file_count(), 6);
+
+        let regular0 = merged.spilled_partition(SpatialPartition::Regular(0))?;
+        assert_eq!(regular0.spill_files().len(), 3);
+        assert_eq!(regular0.geo_statistics().total_geometries(), Some(30));
+        let bbox0 = regular0.geo_statistics().bbox().unwrap();
+        assert_eq!(bbox0.x().lo(), 0.0);
+        assert_eq!(bbox0.x().hi(), 6.0);
+        assert_eq!(bbox0.y().lo(), 0.0);
+        assert_eq!(bbox0.y().hi(), 6.0);
+
+        let regular1 = merged.spilled_partition(SpatialPartition::Regular(1))?;
+        assert_eq!(regular1.spill_files().len(), 1);
+        assert_eq!(regular1.geo_statistics().total_geometries(), Some(13));
+        let bbox1 = regular1.geo_statistics().bbox().unwrap();
+        assert_eq!(bbox1.x().lo(), 10.0);
+        assert_eq!(bbox1.x().hi(), 13.0);
+        assert_eq!(bbox1.y().lo(), -1.0);
+        assert_eq!(bbox1.y().hi(), 3.0);
+
+        let none_partition = merged.spilled_partition(SpatialPartition::None)?;
+        assert_eq!(none_partition.spill_files().len(), 1);
+        assert_eq!(none_partition.geo_statistics().total_geometries(), 
Some(2));
+
+        let multi_partition = 
merged.spilled_partition(SpatialPartition::Multi)?;
+        assert_eq!(multi_partition.spill_files().len(), 1);
+        assert_eq!(multi_partition.geo_statistics().total_geometries(), 
Some(1));
+
+        Ok(())
+    }
+}
diff --git a/rust/sedona-spatial-join/src/stream.rs 
b/rust/sedona-spatial-join/src/stream.rs
index 8451ff2d..edbb41dd 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -38,8 +38,10 @@ use std::sync::Arc;
 use 
crate::evaluated_batch::evaluated_batch_stream::evaluate::create_evaluated_probe_stream;
 use 
crate::evaluated_batch::evaluated_batch_stream::SendableEvaluatedBatchStream;
 use crate::evaluated_batch::EvaluatedBatch;
+use crate::index::partitioned_index_provider::PartitionedIndexProvider;
 use crate::index::SpatialIndex;
 use crate::operand_evaluator::create_operand_evaluator;
+use crate::prepare::SpatialJoinComponents;
 use crate::spatial_predicate::SpatialPredicate;
 use crate::utils::join_utils::{
     adjust_indices_by_join_type, apply_join_filter_to_indices, 
build_batch_from_indices,
@@ -52,6 +54,8 @@ use sedona_common::option::SpatialJoinOptions;
 
 /// Stream for producing spatial join result batches.
 pub(crate) struct SpatialJoinStream {
+    /// The partition id of the probe side stream
+    probe_partition_id: usize,
     /// Schema of joined results
     schema: Arc<Schema>,
     /// join filter
@@ -73,13 +77,18 @@ pub(crate) struct SpatialJoinStream {
     options: SpatialJoinOptions,
     /// Target output batch size
     target_output_batch_size: usize,
-    /// Once future for the spatial index
-    once_fut_spatial_index: OnceFut<SpatialIndex>,
-    /// Once async for the spatial index, will be manually disposed by the 
last finished stream
-    /// to avoid unnecessary memory usage.
-    once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+    /// Once future for the shared partitioned index provider
+    once_fut_spatial_join_components: OnceFut<SpatialJoinComponents>,
+    /// Once async for the provider, disposed by the last finished stream
+    once_async_spatial_join_components: 
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
+    /// Cached index provider reference after it becomes available
+    index_provider: Option<Arc<PartitionedIndexProvider>>,
     /// The spatial index
     spatial_index: Option<Arc<SpatialIndex>>,
+    /// Pending future for building or waiting on a partitioned index
+    pending_index_future: Option<BoxFuture<'static, 
Option<Result<Arc<SpatialIndex>>>>>,
+    /// Total number of regular partitions produced by the provider
+    num_regular_partitions: Option<u32>,
     /// The spatial predicate being evaluated
     spatial_predicate: SpatialPredicate,
 }
@@ -87,6 +96,7 @@ pub(crate) struct SpatialJoinStream {
 impl SpatialJoinStream {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
+        probe_partition_id: usize,
         schema: Arc<Schema>,
         on: &SpatialPredicate,
         filter: Option<JoinFilter>,
@@ -97,8 +107,8 @@ impl SpatialJoinStream {
         join_metrics: SpatialJoinProbeMetrics,
         options: SpatialJoinOptions,
         target_output_batch_size: usize,
-        once_fut_spatial_index: OnceFut<SpatialIndex>,
-        once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+        once_fut_spatial_join_components: OnceFut<SpatialJoinComponents>,
+        once_async_spatial_join_components: 
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
     ) -> Self {
         let evaluator = create_operand_evaluator(on, options.clone());
         let probe_stream = create_evaluated_probe_stream(
@@ -107,6 +117,7 @@ impl SpatialJoinStream {
             join_metrics.join_time.clone(),
         );
         Self {
+            probe_partition_id,
             schema,
             filter,
             join_type,
@@ -114,12 +125,15 @@ impl SpatialJoinStream {
             column_indices,
             probe_side_ordered,
             join_metrics,
-            state: SpatialJoinStreamState::WaitBuildIndex,
+            state: SpatialJoinStreamState::WaitPrepareSpatialJoinComponents,
             options,
             target_output_batch_size,
-            once_fut_spatial_index,
-            once_async_spatial_index,
+            once_fut_spatial_join_components,
+            once_async_spatial_join_components,
+            index_provider: None,
             spatial_index: None,
+            pending_index_future: None,
+            num_regular_partitions: None,
             spatial_predicate: on.clone(),
         }
     }
@@ -169,6 +183,8 @@ impl SpatialJoinProbeMetrics {
 /// This enumeration represents various states of the nested loop join 
algorithm.
 #[allow(clippy::large_enum_variant)]
 pub(crate) enum SpatialJoinStreamState {
+    /// The initial mode: waiting for the spatial join components to become 
available
+    WaitPrepareSpatialJoinComponents,
     /// The initial mode: waiting for the spatial index to be built
     WaitBuildIndex,
     /// Indicates that build-side has been collected, and stream is ready for
@@ -193,6 +209,9 @@ impl SpatialJoinStream {
     ) -> Poll<Option<Result<RecordBatch>>> {
         loop {
             return match &mut self.state {
+                SpatialJoinStreamState::WaitPrepareSpatialJoinComponents => {
+                    
handle_state!(ready!(self.wait_create_spatial_join_components(cx)))
+                }
                 SpatialJoinStreamState::WaitBuildIndex => {
                     handle_state!(ready!(self.wait_build_index(cx)))
                 }
@@ -213,16 +232,97 @@ impl SpatialJoinStream {
         }
     }
 
-    fn wait_build_index(
+    fn wait_create_spatial_join_components(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
-        let index = ready!(self.once_fut_spatial_index.get_shared(cx))?;
-        self.spatial_index = Some(index);
-        self.state = SpatialJoinStreamState::FetchProbeBatch;
+        if self.index_provider.is_none() {
+            let spatial_join_components =
+                ready!(self.once_fut_spatial_join_components.get_shared(cx))?;
+            let provider = 
Arc::clone(&spatial_join_components.partitioned_index_provider);
+            self.num_regular_partitions = 
Some(provider.num_regular_partitions() as u32);
+            self.index_provider = Some(provider);
+        }
+
+        let num_partitions = self
+            .num_regular_partitions
+            .expect("num_regular_partitions should be available");
+        if num_partitions == 0 {
+            // Usually does not happen. The indexed side should have at least 
1 partition.
+            self.state = SpatialJoinStreamState::Completed;
+            return Poll::Ready(Ok(StatefulStreamResult::Continue));
+        }
+
+        if num_partitions > 1 {
+            return Poll::Ready(sedona_internal_err!(
+                "Multi-partitioned spatial join is not supported yet"
+            ));
+        }
+
+        self.state = SpatialJoinStreamState::WaitBuildIndex;
         Poll::Ready(Ok(StatefulStreamResult::Continue))
     }
 
+    fn wait_build_index(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        let num_partitions = self
+            .num_regular_partitions
+            .expect("num_regular_partitions should be available");
+        let partition_id = 0;
+        if partition_id >= num_partitions {
+            self.state = SpatialJoinStreamState::Completed;
+            return Poll::Ready(Ok(StatefulStreamResult::Continue));
+        }
+
+        if self.pending_index_future.is_none() {
+            let provider = Arc::clone(
+                self.index_provider
+                    .as_ref()
+                    .expect("Partitioned index provider should be available"),
+            );
+            let future = {
+                log::debug!(
+                    "[Partition {}] Building index for spatial partition {}",
+                    self.probe_partition_id,
+                    partition_id
+                );
+                async move { 
provider.build_or_wait_for_index(partition_id).await }.boxed()
+            };
+            self.pending_index_future = Some(future);
+        }
+
+        let future = self
+            .pending_index_future
+            .as_mut()
+            .expect("pending future must exist");
+
+        match future.poll_unpin(cx) {
+            Poll::Ready(Some(Ok(index))) => {
+                self.pending_index_future = None;
+                self.spatial_index = Some(index);
+                log::debug!(
+                    "[Partition {}] Start probing spatial partition {}",
+                    self.probe_partition_id,
+                    partition_id
+                );
+                self.state = SpatialJoinStreamState::FetchProbeBatch;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
+            }
+            Poll::Ready(Some(Err(err))) => {
+                self.pending_index_future = None;
+                Poll::Ready(Err(err))
+            }
+            Poll::Ready(None) => {
+                self.pending_index_future = None;
+                self.state = SpatialJoinStreamState::Completed;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
     fn fetch_probe_batch(
         &mut self,
         cx: &mut std::task::Context<'_>,
@@ -318,8 +418,13 @@ impl SpatialJoinStream {
 
             // Drop the once async to avoid holding a long-living reference to 
the spatial index.
             // The spatial index will be dropped when this stream is dropped.
-            let mut once_async = self.once_async_spatial_index.lock();
+            let mut once_async = 
self.once_async_spatial_join_components.lock();
             once_async.take();
+
+            if let Some(provider) = self.index_provider.as_ref() {
+                provider.dispose_index(0);
+                assert!(provider.num_loaded_indexes() == 0);
+            }
         }
 
         // Initial setup for processing unmatched build batches
diff --git a/rust/sedona-spatial-join/src/utils.rs 
b/rust/sedona-spatial-join/src/utils.rs
index 42a257f0..4d73a002 100644
--- a/rust/sedona-spatial-join/src/utils.rs
+++ b/rust/sedona-spatial-join/src/utils.rs
@@ -17,6 +17,7 @@
 
 pub(crate) mod arrow_utils;
 pub(crate) mod bbox_sampler;
+pub(crate) mod disposable_async_cell;
 pub(crate) mod init_once_array;
 pub(crate) mod join_utils;
 pub(crate) mod once_fut;
diff --git a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs 
b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
index 498f3863..99280162 100644
--- a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
+++ b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#![allow(unused)]
 use datafusion_common::{DataFusionError, Result};
 use fastrand::Rng;
 use sedona_geometry::bounding_box::BoundingBox;
diff --git a/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs 
b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs
new file mode 100644
index 00000000..e738e034
--- /dev/null
+++ b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs
@@ -0,0 +1,204 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::fmt;
+
+use parking_lot::Mutex;
+use tokio::sync::Notify;
+
+/// Error returned when writing to a [`DisposableAsyncCell`] fails.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) enum CellSetError {
+    /// The cell has already been disposed, so new values are rejected.
+    Disposed,
+
+    /// The cell already has a value; the existing value is kept unchanged.
+    AlreadySet,
+}
+
+/// An asynchronous cell that can be set at most once before either being
+/// disposed or read by any number of waiters.
+///
+/// This is used as a lightweight one-shot coordination primitive in the 
spatial
+/// join implementation. For example, `PartitionedIndexProvider` keeps one
+/// `DisposableAsyncCell` per regular partition to publish either a 
successfully
+/// built `SpatialIndex` (or the build error) exactly once. Concurrent
+/// `SpatialJoinStream`s racing to probe the same partition can then await the
+/// same shared result instead of building duplicate indexes.
+///
+/// When an index is no longer needed (e.g. the last stream finishes a
+/// partition), the cell can be disposed to free resources.
+///
+/// Awaiters calling [`DisposableAsyncCell::get`] will park until a value is 
set
+/// or the cell is disposed. Once disposed, `get` returns `None` and `set`
+/// returns [`CellSetError::Disposed`].
+pub(crate) struct DisposableAsyncCell<T> {
+    /// Current contents of the cell (empty, holding a value, or disposed).
+    state: Mutex<CellState<T>>,
+    /// Wakes tasks parked in [`DisposableAsyncCell::get`] on state changes.
+    notify: Notify,
+}
+
+impl<T> fmt::Debug for DisposableAsyncCell<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Manual impl (rather than #[derive(Debug)]) so the cell is Debug
+        // even when `T` is not; the contents are intentionally opaque.
+        write!(f, "DisposableAsyncCell")
+    }
+}
+
+impl<T> Default for DisposableAsyncCell<T> {
+    /// Equivalent to [`DisposableAsyncCell::new`]: an empty, undisposed cell.
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<T> DisposableAsyncCell<T> {
+    /// Creates a new empty cell with no stored value.
+    pub(crate) fn new() -> Self {
+        Self {
+            state: Mutex::new(CellState::Empty),
+            notify: Notify::new(),
+        }
+    }
+
+    /// Marks the cell as disposed and wakes every waiter.
+    ///
+    /// Any value currently stored in the cell is dropped by the overwrite.
+    /// After this call, `get` returns `None` and `set` fails with
+    /// [`CellSetError::Disposed`].
+    pub(crate) fn dispose(&self) {
+        {
+            // Scope the guard so the lock is released before notifying,
+            // letting woken waiters acquire it immediately.
+            let mut state = self.state.lock();
+            *state = CellState::Disposed;
+        }
+        self.notify.notify_waiters();
+    }
+
+    /// Check whether the cell has a value or not.
+    pub(crate) fn is_set(&self) -> bool {
+        let state = self.state.lock();
+        matches!(*state, CellState::Value(_))
+    }
+
+    /// Check whether the cell is empty (neither set nor disposed).
+    pub(crate) fn is_empty(&self) -> bool {
+        let state = self.state.lock();
+        matches!(*state, CellState::Empty)
+    }
+}
+
+impl<T: Clone> DisposableAsyncCell<T> {
+    /// Waits until a value is set or the cell is disposed.
+    /// Returns `None` if the cell is disposed without a value.
+    pub(crate) async fn get(&self) -> Option<T> {
+        loop {
+            // Create the `Notified` future *before* inspecting the state.
+            // tokio guarantees a `Notified` created before a
+            // `notify_waiters()` call receives that wakeup, so a `set` or
+            // `dispose` racing between the check below and the await cannot
+            // be lost.
+            let notified = self.notify.notified();
+            {
+                // Inner scope drops the lock guard before awaiting; holding
+                // a parking_lot guard across an await point would be unsound.
+                let state = self.state.lock();
+                match &*state {
+                    CellState::Value(val) => return Some(val.clone()),
+                    CellState::Disposed => return None,
+                    CellState::Empty => {}
+                }
+            }
+            notified.await;
+        }
+    }
+
+    /// Stores the provided value if the cell is still empty.
+    /// Fails if a value already exists or the cell has been disposed.
+    pub(crate) fn set(&self, value: T) -> std::result::Result<(), 
CellSetError> {
+        {
+            let mut state = self.state.lock();
+            match &mut *state {
+                CellState::Empty => *state = CellState::Value(value),
+                CellState::Disposed => return Err(CellSetError::Disposed),
+                CellState::Value(_) => return Err(CellSetError::AlreadySet),
+            }
+        }
+
+        // Notify outside the lock so woken readers can take it right away.
+        self.notify.notify_waiters();
+        Ok(())
+    }
+}
+
+/// Internal state machine of [`DisposableAsyncCell`]: a cell starts `Empty`,
+/// may move to `Value` at most once via `set`, and may transition to
+/// `Disposed` from any state; no transition leaves `Disposed`.
+enum CellState<T> {
+    Empty,
+    Value(T),
+    Disposed,
+}
+
+#[cfg(test)]
+mod tests {
+    // Unit tests covering the set/get/dispose interactions of
+    // `DisposableAsyncCell`, including multi-waiter wakeup behavior.
+    use super::{CellSetError, DisposableAsyncCell};
+    use std::sync::Arc;
+    use tokio::task;
+    use tokio::time::{sleep, Duration};
+
+    #[tokio::test]
+    async fn get_returns_value_once_set() {
+        let cell = DisposableAsyncCell::new();
+        cell.set(42).expect("set succeeds");
+        assert_eq!(Some(42), cell.get().await);
+    }
+
+    #[tokio::test]
+    async fn multiple_waiters_receive_same_value() {
+        // Both waiters are spawned before `set`; each must observe a clone
+        // of the single published value.
+        let cell = Arc::new(DisposableAsyncCell::new());
+        let cloned = Arc::clone(&cell);
+        let waiter_one = task::spawn(async move { cloned.get().await });
+        let cloned = Arc::clone(&cell);
+        let waiter_two = task::spawn(async move { cloned.get().await });
+
+        cell.set(String::from("value")).expect("set succeeds");
+        assert_eq!(Some("value".to_string()), waiter_one.await.unwrap());
+        assert_eq!(Some("value".to_string()), waiter_two.await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn dispose_unblocks_waiters() {
+        let cell = Arc::new(DisposableAsyncCell::<i32>::new());
+        let waiter = tokio::spawn({
+            let cloned = Arc::clone(&cell);
+            async move { cloned.get().await }
+        });
+
+        cell.dispose();
+        assert_eq!(None, waiter.await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn set_after_dispose_fails() {
+        let cell = DisposableAsyncCell::new();
+        cell.dispose();
+        assert_eq!(Err(CellSetError::Disposed), cell.set(5));
+    }
+
+    #[tokio::test]
+    async fn set_twice_rejects_second_value() {
+        let cell = DisposableAsyncCell::new();
+        cell.set("first").expect("initial set succeeds");
+        assert_eq!(Err(CellSetError::AlreadySet), cell.set("second"));
+        assert_eq!(Some("first"), cell.get().await);
+    }
+
+    #[tokio::test]
+    async fn get_waits_until_value_is_set() {
+        let cell = Arc::new(DisposableAsyncCell::new());
+        let cloned = Arc::clone(&cell);
+        let waiter = tokio::spawn(async move { cloned.get().await });
+
+        // NOTE(review): relies on a real 20 ms wall-clock sleep to infer the
+        // waiter is still parked; could be flaky on a heavily loaded CI
+        // machine — consider a longer margin or tokio::time utilities.
+        sleep(Duration::from_millis(20)).await;
+        assert!(!waiter.is_finished());
+
+        cell.set(99).expect("set succeeds");
+        assert_eq!(Some(99), waiter.await.unwrap());
+    }
+}

Reply via email to