This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4edd904ee refactor: reorganize shuffle crate module structure (#3772)
4edd904ee is described below

commit 4edd904eeaca8c985e5dd928486888e4ee7e9f60
Author: Andy Grove <[email protected]>
AuthorDate: Fri Mar 27 11:50:17 2026 -0600

    refactor: reorganize shuffle crate module structure (#3772)
---
 .../core/src/execution/operators/shuffle_scan.rs   |   7 +-
 native/shuffle/benches/row_columnar.rs             |   5 +-
 native/shuffle/src/comet_partitioning.rs           |   1 +
 native/shuffle/src/ipc.rs                          |  52 +++++
 native/shuffle/src/lib.rs                          |   5 +-
 native/shuffle/src/metrics.rs                      |   1 +
 native/shuffle/src/partitioners/mod.rs             |  13 +-
 native/shuffle/src/partitioners/multi_partition.rs |   1 +
 .../src/partitioners/partitioned_batch_iterator.rs |   1 +
 .../shuffle/src/partitioners/{mod.rs => traits.rs} |   8 -
 native/shuffle/src/spark_unsafe/list.rs            |   7 +-
 native/shuffle/src/spark_unsafe/map.rs             |   1 +
 native/shuffle/src/spark_unsafe/mod.rs             |   1 +
 native/shuffle/src/spark_unsafe/row.rs             | 214 +-------------------
 native/shuffle/src/spark_unsafe/unsafe_object.rs   | 224 +++++++++++++++++++++
 native/shuffle/src/writers/buf_batch_writer.rs     |   2 +-
 native/shuffle/src/writers/checksum.rs             |  81 ++++++++
 native/shuffle/src/writers/mod.rs                  |   8 +-
 .../{codec.rs => writers/shuffle_block_writer.rs}  |  97 +--------
 .../src/writers/{partition_writer.rs => spill.rs}  |   4 +-
 20 files changed, 395 insertions(+), 338 deletions(-)

diff --git a/native/core/src/execution/operators/shuffle_scan.rs 
b/native/core/src/execution/operators/shuffle_scan.rs
index 824965d48..a1ad52310 100644
--- a/native/core/src/execution/operators/shuffle_scan.rs
+++ b/native/core/src/execution/operators/shuffle_scan.rs
@@ -18,8 +18,7 @@
 use crate::{
     errors::CometError,
     execution::{
-        operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID,
-        shuffle::codec::read_ipc_compressed,
+        operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, 
shuffle::ipc::read_ipc_compressed,
     },
     jvm_bridge::{jni_call, JVMClasses},
 };
@@ -352,7 +351,7 @@ impl RecordBatchStream for ShuffleScanStream {
 
 #[cfg(test)]
 mod tests {
-    use crate::execution::shuffle::codec::{CompressionCodec, 
ShuffleBlockWriter};
+    use crate::execution::shuffle::{CompressionCodec, ShuffleBlockWriter};
     use arrow::array::{Int32Array, StringArray};
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
@@ -360,7 +359,7 @@ mod tests {
     use std::io::Cursor;
     use std::sync::Arc;
 
-    use crate::execution::shuffle::codec::read_ipc_compressed;
+    use crate::execution::shuffle::ipc::read_ipc_compressed;
 
     #[test]
     #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd)
diff --git a/native/shuffle/benches/row_columnar.rs 
b/native/shuffle/benches/row_columnar.rs
index 7d3951b4d..cc98f3fac 100644
--- a/native/shuffle/benches/row_columnar.rs
+++ b/native/shuffle/benches/row_columnar.rs
@@ -23,9 +23,8 @@
 
 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields};
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use datafusion_comet_shuffle::spark_unsafe::row::{
-    process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow,
-};
+use 
datafusion_comet_shuffle::spark_unsafe::row::{process_sorted_row_partition, 
SparkUnsafeRow};
+use datafusion_comet_shuffle::spark_unsafe::unsafe_object::SparkUnsafeObject;
 use datafusion_comet_shuffle::CompressionCodec;
 use std::sync::Arc;
 use tempfile::Builder;
diff --git a/native/shuffle/src/comet_partitioning.rs 
b/native/shuffle/src/comet_partitioning.rs
index c269539a6..15912e648 100644
--- a/native/shuffle/src/comet_partitioning.rs
+++ b/native/shuffle/src/comet_partitioning.rs
@@ -19,6 +19,7 @@ use arrow::row::{OwnedRow, RowConverter};
 use datafusion::physical_expr::{LexOrdering, PhysicalExpr};
 use std::sync::Arc;
 
+/// Partitioning scheme for distributing rows across shuffle output partitions.
 #[derive(Debug, Clone)]
 pub enum CometPartitioning {
     SinglePartition,
diff --git a/native/shuffle/src/ipc.rs b/native/shuffle/src/ipc.rs
new file mode 100644
index 000000000..81ee41332
--- /dev/null
+++ b/native/shuffle/src/ipc.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::RecordBatch;
+use arrow::ipc::reader::StreamReader;
+use datafusion::common::DataFusionError;
+use datafusion::error::Result;
+
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
+    match &bytes[0..4] {
+        b"SNAP" => {
+            let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
+            let mut reader =
+                unsafe { StreamReader::try_new(decoder, 
None)?.with_skip_validation(true) };
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"LZ4_" => {
+            let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
+            let mut reader =
+                unsafe { StreamReader::try_new(decoder, 
None)?.with_skip_validation(true) };
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"ZSTD" => {
+            let decoder = zstd::Decoder::new(&bytes[4..])?;
+            let mut reader =
+                unsafe { StreamReader::try_new(decoder, 
None)?.with_skip_validation(true) };
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        b"NONE" => {
+            let mut reader =
+                unsafe { StreamReader::try_new(&bytes[4..], 
None)?.with_skip_validation(true) };
+            reader.next().unwrap().map_err(|e| e.into())
+        }
+        other => Err(DataFusionError::Execution(format!(
+            "Failed to decode batch: invalid compression codec: {other:?}"
+        ))),
+    }
+}
diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs
index 7c2fc8403..f29588f2e 100644
--- a/native/shuffle/src/lib.rs
+++ b/native/shuffle/src/lib.rs
@@ -15,14 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub mod codec;
 pub(crate) mod comet_partitioning;
+pub mod ipc;
 pub(crate) mod metrics;
 pub(crate) mod partitioners;
 mod shuffle_writer;
 pub mod spark_unsafe;
 pub(crate) mod writers;
 
-pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter};
 pub use comet_partitioning::CometPartitioning;
+pub use ipc::read_ipc_compressed;
 pub use shuffle_writer::ShuffleWriterExec;
+pub use writers::{CompressionCodec, ShuffleBlockWriter};
diff --git a/native/shuffle/src/metrics.rs b/native/shuffle/src/metrics.rs
index 1aba4677d..1de751cf4 100644
--- a/native/shuffle/src/metrics.rs
+++ b/native/shuffle/src/metrics.rs
@@ -19,6 +19,7 @@ use datafusion::physical_plan::metrics::{
     BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time,
 };
 
+/// Execution metrics for a shuffle partition operation.
 pub(crate) struct ShufflePartitionerMetrics {
     /// metrics
     pub(crate) baseline: BaselineMetrics,
diff --git a/native/shuffle/src/partitioners/mod.rs 
b/native/shuffle/src/partitioners/mod.rs
index a6d589677..3eedef62c 100644
--- a/native/shuffle/src/partitioners/mod.rs
+++ b/native/shuffle/src/partitioners/mod.rs
@@ -18,18 +18,9 @@
 mod multi_partition;
 mod partitioned_batch_iterator;
 mod single_partition;
-
-use arrow::record_batch::RecordBatch;
-use datafusion::common::Result;
+mod traits;
 
 pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
 pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
 pub(crate) use single_partition::SinglePartitionShufflePartitioner;
-
-#[async_trait::async_trait]
-pub(crate) trait ShufflePartitioner: Send + Sync {
-    /// Insert a batch into the partitioner
-    async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>;
-    /// Write shuffle data and shuffle index file to disk
-    fn shuffle_write(&mut self) -> Result<()>;
-}
+pub(crate) use traits::ShufflePartitioner;
diff --git a/native/shuffle/src/partitioners/multi_partition.rs 
b/native/shuffle/src/partitioners/multi_partition.rs
index 42290c551..655bee351 100644
--- a/native/shuffle/src/partitioners/multi_partition.rs
+++ b/native/shuffle/src/partitioners/multi_partition.rs
@@ -39,6 +39,7 @@ use std::io::{BufReader, BufWriter, Seek, Write};
 use std::sync::Arc;
 use tokio::time::Instant;
 
+/// Reusable scratch buffers for computing row-to-partition assignments.
 #[derive(Default)]
 struct ScratchSpace {
     /// Hashes for each row in the current batch.
diff --git a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs 
b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
index 77010938c..8309a8ed4 100644
--- a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
+++ b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs
@@ -50,6 +50,7 @@ impl PartitionedBatchesProducer {
     }
 }
 
+/// Iterates over the shuffled record batches belonging to a single output 
partition.
 pub(crate) struct PartitionedBatchIterator<'a> {
     record_batches: Vec<&'a RecordBatch>,
     batch_size: usize,
diff --git a/native/shuffle/src/partitioners/mod.rs 
b/native/shuffle/src/partitioners/traits.rs
similarity index 80%
copy from native/shuffle/src/partitioners/mod.rs
copy to native/shuffle/src/partitioners/traits.rs
index a6d589677..9572b70db 100644
--- a/native/shuffle/src/partitioners/mod.rs
+++ b/native/shuffle/src/partitioners/traits.rs
@@ -15,17 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod multi_partition;
-mod partitioned_batch_iterator;
-mod single_partition;
-
 use arrow::record_batch::RecordBatch;
 use datafusion::common::Result;
 
-pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
-pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
-pub(crate) use single_partition::SinglePartitionShufflePartitioner;
-
 #[async_trait::async_trait]
 pub(crate) trait ShufflePartitioner: Send + Sync {
     /// Insert a batch into the partitioner
diff --git a/native/shuffle/src/spark_unsafe/list.rs 
b/native/shuffle/src/spark_unsafe/list.rs
index 4eb293895..3fea3fade 100644
--- a/native/shuffle/src/spark_unsafe/list.rs
+++ b/native/shuffle/src/spark_unsafe/list.rs
@@ -17,10 +17,8 @@
 
 use crate::spark_unsafe::{
     map::append_map_elements,
-    row::{
-        append_field, downcast_builder_ref, impl_primitive_accessors, 
SparkUnsafeObject,
-        SparkUnsafeRow,
-    },
+    row::{append_field, downcast_builder_ref, SparkUnsafeRow},
+    unsafe_object::{impl_primitive_accessors, SparkUnsafeObject},
 };
 use arrow::array::{
     builder::{
@@ -86,6 +84,7 @@ macro_rules! impl_append_to_builder {
     };
 }
 
+/// A Spark `UnsafeArray` backed by JVM-allocated memory, providing element 
access by index.
 pub struct SparkUnsafeArray {
     row_addr: i64,
     num_elements: usize,
diff --git a/native/shuffle/src/spark_unsafe/map.rs 
b/native/shuffle/src/spark_unsafe/map.rs
index 57444cee7..026e6f71d 100644
--- a/native/shuffle/src/spark_unsafe/map.rs
+++ b/native/shuffle/src/spark_unsafe/map.rs
@@ -20,6 +20,7 @@ use arrow::array::builder::{ArrayBuilder, MapBuilder, 
MapFieldNames};
 use arrow::datatypes::{DataType, FieldRef};
 use datafusion_comet_jni_bridge::errors::CometError;
 
+/// A Spark `UnsafeMap` backed by JVM-allocated memory, containing parallel 
keys and values arrays.
 pub struct SparkUnsafeMap {
     pub(crate) keys: SparkUnsafeArray,
     pub(crate) values: SparkUnsafeArray,
diff --git a/native/shuffle/src/spark_unsafe/mod.rs 
b/native/shuffle/src/spark_unsafe/mod.rs
index 6390a0f23..abda69a08 100644
--- a/native/shuffle/src/spark_unsafe/mod.rs
+++ b/native/shuffle/src/spark_unsafe/mod.rs
@@ -18,3 +18,4 @@
 pub mod list;
 mod map;
 pub mod row;
+pub mod unsafe_object;
diff --git a/native/shuffle/src/spark_unsafe/row.rs 
b/native/shuffle/src/spark_unsafe/row.rs
index da980af8f..3c9867719 100644
--- a/native/shuffle/src/spark_unsafe/row.rs
+++ b/native/shuffle/src/spark_unsafe/row.rs
@@ -17,11 +17,13 @@
 
 //! Utils for supporting native sort-based columnar shuffle.
 
-use crate::codec::{Checksum, ShuffleBlockWriter};
+use crate::spark_unsafe::unsafe_object::{impl_primitive_accessors, 
SparkUnsafeObject};
 use crate::spark_unsafe::{
-    list::{append_list_element, SparkUnsafeArray},
-    map::{append_map_elements, get_map_key_value_fields, SparkUnsafeMap},
+    list::append_list_element,
+    map::{append_map_elements, get_map_key_value_fields},
 };
+use crate::writers::Checksum;
+use crate::writers::ShuffleBlockWriter;
 use arrow::array::{
     builder::{
         ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, 
Date32Builder,
@@ -36,219 +38,17 @@ use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use arrow::error::ArrowError;
 use datafusion::physical_plan::metrics::Time;
-use datafusion_comet_common::bytes_to_i128;
 use datafusion_comet_jni_bridge::errors::CometError;
 use jni::sys::{jint, jlong};
 use std::{
     fs::OpenOptions,
     io::{Cursor, Write},
-    str::from_utf8,
     sync::Arc,
 };
 
-const MAX_LONG_DIGITS: u8 = 18;
 const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
 
-/// A common trait for Spark Unsafe classes that can be used to access the 
underlying data,
-/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that 
can be used to
-/// access the underlying data with index.
-///
-/// # Safety
-///
-/// Implementations must ensure that:
-/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
-/// - `get_element_offset()` returns a valid pointer within the row/array data 
region
-/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
-/// - The memory remains valid for the lifetime of the object (guaranteed by 
JVM ownership)
-///
-/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer 
operations but are
-/// safe to call as long as:
-/// - The index is within bounds (caller's responsibility)
-/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
-///
-/// # Alignment
-///
-/// Primitive accessor methods are implemented separately for each type 
because they have
-/// different alignment guarantees:
-/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is 
a multiple of 8,
-///   and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
-/// - `SparkUnsafeArray`: The array base address may be unaligned when nested 
within a row's
-///   variable-length region, so accessors use `ptr::read_unaligned()`.
-pub trait SparkUnsafeObject {
-    /// Returns the address of the row.
-    fn get_row_addr(&self) -> i64;
-
-    /// Returns the offset of the element at the given index.
-    fn get_element_offset(&self, index: usize, element_size: usize) -> *const 
u8;
-
-    fn get_boolean(&self, index: usize) -> bool;
-    fn get_byte(&self, index: usize) -> i8;
-    fn get_short(&self, index: usize) -> i16;
-    fn get_int(&self, index: usize) -> i32;
-    fn get_long(&self, index: usize) -> i64;
-    fn get_float(&self, index: usize) -> f32;
-    fn get_double(&self, index: usize) -> f64;
-    fn get_date(&self, index: usize) -> i32;
-    fn get_timestamp(&self, index: usize) -> i64;
-
-    /// Returns the offset and length of the element at the given index.
-    #[inline]
-    fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
-        let offset_and_size = self.get_long(index);
-        let offset = (offset_and_size >> 32) as i32;
-        let len = offset_and_size as i32;
-        (offset, len)
-    }
-
-    /// Returns string value at the given index of the object.
-    fn get_string(&self, index: usize) -> &str {
-        let (offset, len) = self.get_offset_and_len(index);
-        let addr = self.get_row_addr() + offset as i64;
-        // SAFETY: addr points to valid UTF-8 string data within the 
variable-length region.
-        // Offset and length are read from the fixed-length portion of the 
row/array.
-        debug_assert!(addr != 0, "get_string: null address at index {index}");
-        debug_assert!(
-            len >= 0,
-            "get_string: negative length {len} at index {index}"
-        );
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const 
u8, len as usize) };
-
-        from_utf8(slice).unwrap()
-    }
-
-    /// Returns binary value at the given index of the object.
-    fn get_binary(&self, index: usize) -> &[u8] {
-        let (offset, len) = self.get_offset_and_len(index);
-        let addr = self.get_row_addr() + offset as i64;
-        // SAFETY: addr points to valid binary data within the variable-length 
region.
-        // Offset and length are read from the fixed-length portion of the 
row/array.
-        debug_assert!(addr != 0, "get_binary: null address at index {index}");
-        debug_assert!(
-            len >= 0,
-            "get_binary: negative length {len} at index {index}"
-        );
-        unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
-    }
-
-    /// Returns decimal value at the given index of the object.
-    fn get_decimal(&self, index: usize, precision: u8) -> i128 {
-        if precision <= MAX_LONG_DIGITS {
-            self.get_long(index) as i128
-        } else {
-            let slice = self.get_binary(index);
-            bytes_to_i128(slice)
-        }
-    }
-
-    /// Returns struct value at the given index of the object.
-    fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow {
-        let (offset, len) = self.get_offset_and_len(index);
-        let mut row = SparkUnsafeRow::new_with_num_fields(num_fields);
-        row.point_to(self.get_row_addr() + offset as i64, len);
-
-        row
-    }
-
-    /// Returns array value at the given index of the object.
-    fn get_array(&self, index: usize) -> SparkUnsafeArray {
-        let (offset, _) = self.get_offset_and_len(index);
-        SparkUnsafeArray::new(self.get_row_addr() + offset as i64)
-    }
-
-    fn get_map(&self, index: usize) -> SparkUnsafeMap {
-        let (offset, len) = self.get_offset_and_len(index);
-        SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len)
-    }
-}
-
-/// Generates primitive accessor implementations for `SparkUnsafeObject`.
-///
-/// Uses `$read_method` to read typed values from raw pointers:
-/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte 
aligned)
-/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
-macro_rules! impl_primitive_accessors {
-    ($read_method:ident) => {
-        #[inline]
-        fn get_boolean(&self, index: usize) -> bool {
-            let addr = self.get_element_offset(index, 1);
-            debug_assert!(
-                !addr.is_null(),
-                "get_boolean: null pointer at index {index}"
-            );
-            // SAFETY: addr points to valid element data within the row/array 
region.
-            unsafe { *addr != 0 }
-        }
-
-        #[inline]
-        fn get_byte(&self, index: usize) -> i8 {
-            let addr = self.get_element_offset(index, 1);
-            debug_assert!(!addr.is_null(), "get_byte: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (1 byte) within the 
row/array region.
-            unsafe { *(addr as *const i8) }
-        }
-
-        #[inline]
-        fn get_short(&self, index: usize) -> i16 {
-            let addr = self.get_element_offset(index, 2) as *const i16;
-            debug_assert!(!addr.is_null(), "get_short: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (2 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_int(&self, index: usize) -> i32 {
-            let addr = self.get_element_offset(index, 4) as *const i32;
-            debug_assert!(!addr.is_null(), "get_int: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_long(&self, index: usize) -> i64 {
-            let addr = self.get_element_offset(index, 8) as *const i64;
-            debug_assert!(!addr.is_null(), "get_long: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_float(&self, index: usize) -> f32 {
-            let addr = self.get_element_offset(index, 4) as *const f32;
-            debug_assert!(!addr.is_null(), "get_float: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_double(&self, index: usize) -> f64 {
-            let addr = self.get_element_offset(index, 8) as *const f64;
-            debug_assert!(!addr.is_null(), "get_double: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_date(&self, index: usize) -> i32 {
-            let addr = self.get_element_offset(index, 4) as *const i32;
-            debug_assert!(!addr.is_null(), "get_date: null pointer at index 
{index}");
-            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-
-        #[inline]
-        fn get_timestamp(&self, index: usize) -> i64 {
-            let addr = self.get_element_offset(index, 8) as *const i64;
-            debug_assert!(
-                !addr.is_null(),
-                "get_timestamp: null pointer at index {index}"
-            );
-            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
-            unsafe { addr.$read_method() }
-        }
-    };
-}
-pub(crate) use impl_primitive_accessors;
-
+/// A Spark `UnsafeRow` backed by JVM-allocated memory, providing field access 
by index.
 pub struct SparkUnsafeRow {
     row_addr: i64,
     row_size: i32,
@@ -323,7 +123,7 @@ impl SparkUnsafeRow {
     }
 
     /// Points the row to the given address with specified row size.
-    fn point_to(&mut self, row_addr: i64, row_size: i32) {
+    pub(crate) fn point_to(&mut self, row_addr: i64, row_size: i32) {
         self.row_addr = row_addr;
         self.row_size = row_size;
     }
diff --git a/native/shuffle/src/spark_unsafe/unsafe_object.rs 
b/native/shuffle/src/spark_unsafe/unsafe_object.rs
new file mode 100644
index 000000000..f32ea8c23
--- /dev/null
+++ b/native/shuffle/src/spark_unsafe/unsafe_object.rs
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::list::SparkUnsafeArray;
+use super::map::SparkUnsafeMap;
+use super::row::SparkUnsafeRow;
+use datafusion_comet_common::bytes_to_i128;
+use std::str::from_utf8;
+
+const MAX_LONG_DIGITS: u8 = 18;
+
+/// A common trait for Spark Unsafe classes that can be used to access the 
underlying data,
+/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that 
can be used to
+/// access the underlying data with index.
+///
+/// # Safety
+///
+/// Implementations must ensure that:
+/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
+/// - `get_element_offset()` returns a valid pointer within the row/array data 
region
+/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
+/// - The memory remains valid for the lifetime of the object (guaranteed by 
JVM ownership)
+///
+/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer 
operations but are
+/// safe to call as long as:
+/// - The index is within bounds (caller's responsibility)
+/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
+///
+/// # Alignment
+///
+/// Primitive accessor methods are implemented separately for each type 
because they have
+/// different alignment guarantees:
+/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is 
a multiple of 8,
+///   and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
+/// - `SparkUnsafeArray`: The array base address may be unaligned when nested 
within a row's
+///   variable-length region, so accessors use `ptr::read_unaligned()`.
+pub trait SparkUnsafeObject {
+    /// Returns the address of the row.
+    fn get_row_addr(&self) -> i64;
+
+    /// Returns the offset of the element at the given index.
+    fn get_element_offset(&self, index: usize, element_size: usize) -> *const 
u8;
+
+    fn get_boolean(&self, index: usize) -> bool;
+    fn get_byte(&self, index: usize) -> i8;
+    fn get_short(&self, index: usize) -> i16;
+    fn get_int(&self, index: usize) -> i32;
+    fn get_long(&self, index: usize) -> i64;
+    fn get_float(&self, index: usize) -> f32;
+    fn get_double(&self, index: usize) -> f64;
+    fn get_date(&self, index: usize) -> i32;
+    fn get_timestamp(&self, index: usize) -> i64;
+
+    /// Returns the offset and length of the element at the given index.
+    #[inline]
+    fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
+        let offset_and_size = self.get_long(index);
+        let offset = (offset_and_size >> 32) as i32;
+        let len = offset_and_size as i32;
+        (offset, len)
+    }
+
+    /// Returns string value at the given index of the object.
+    fn get_string(&self, index: usize) -> &str {
+        let (offset, len) = self.get_offset_and_len(index);
+        let addr = self.get_row_addr() + offset as i64;
+        // SAFETY: addr points to valid UTF-8 string data within the 
variable-length region.
+        // Offset and length are read from the fixed-length portion of the 
row/array.
+        debug_assert!(addr != 0, "get_string: null address at index {index}");
+        debug_assert!(
+            len >= 0,
+            "get_string: negative length {len} at index {index}"
+        );
+        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const 
u8, len as usize) };
+
+        from_utf8(slice).unwrap()
+    }
+
+    /// Returns binary value at the given index of the object.
+    fn get_binary(&self, index: usize) -> &[u8] {
+        let (offset, len) = self.get_offset_and_len(index);
+        let addr = self.get_row_addr() + offset as i64;
+        // SAFETY: addr points to valid binary data within the variable-length 
region.
+        // Offset and length are read from the fixed-length portion of the 
row/array.
+        debug_assert!(addr != 0, "get_binary: null address at index {index}");
+        debug_assert!(
+            len >= 0,
+            "get_binary: negative length {len} at index {index}"
+        );
+        unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
+    }
+
+    /// Returns decimal value at the given index of the object.
+    fn get_decimal(&self, index: usize, precision: u8) -> i128 {
+        if precision <= MAX_LONG_DIGITS {
+            self.get_long(index) as i128
+        } else {
+            let slice = self.get_binary(index);
+            bytes_to_i128(slice)
+        }
+    }
+
+    /// Returns struct value at the given index of the object.
+    fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow {
+        let (offset, len) = self.get_offset_and_len(index);
+        let mut row = SparkUnsafeRow::new_with_num_fields(num_fields);
+        row.point_to(self.get_row_addr() + offset as i64, len);
+
+        row
+    }
+
+    /// Returns array value at the given index of the object.
+    fn get_array(&self, index: usize) -> SparkUnsafeArray {
+        let (offset, _) = self.get_offset_and_len(index);
+        SparkUnsafeArray::new(self.get_row_addr() + offset as i64)
+    }
+
+    fn get_map(&self, index: usize) -> SparkUnsafeMap {
+        let (offset, len) = self.get_offset_and_len(index);
+        SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len)
+    }
+}
+
+/// Generates primitive accessor implementations for `SparkUnsafeObject`.
+///
+/// Uses `$read_method` to read typed values from raw pointers:
+/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte 
aligned)
+/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
+macro_rules! impl_primitive_accessors {
+    ($read_method:ident) => {
+        #[inline]
+        fn get_boolean(&self, index: usize) -> bool {
+            let addr = self.get_element_offset(index, 1);
+            debug_assert!(
+                !addr.is_null(),
+                "get_boolean: null pointer at index {index}"
+            );
+            // SAFETY: addr points to valid element data within the row/array 
region.
+            unsafe { *addr != 0 }
+        }
+
+        #[inline]
+        fn get_byte(&self, index: usize) -> i8 {
+            let addr = self.get_element_offset(index, 1);
+            debug_assert!(!addr.is_null(), "get_byte: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (1 byte) within the 
row/array region.
+            unsafe { *(addr as *const i8) }
+        }
+
+        #[inline]
+        fn get_short(&self, index: usize) -> i16 {
+            let addr = self.get_element_offset(index, 2) as *const i16;
+            debug_assert!(!addr.is_null(), "get_short: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (2 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_int(&self, index: usize) -> i32 {
+            let addr = self.get_element_offset(index, 4) as *const i32;
+            debug_assert!(!addr.is_null(), "get_int: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_long(&self, index: usize) -> i64 {
+            let addr = self.get_element_offset(index, 8) as *const i64;
+            debug_assert!(!addr.is_null(), "get_long: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_float(&self, index: usize) -> f32 {
+            let addr = self.get_element_offset(index, 4) as *const f32;
+            debug_assert!(!addr.is_null(), "get_float: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_double(&self, index: usize) -> f64 {
+            let addr = self.get_element_offset(index, 8) as *const f64;
+            debug_assert!(!addr.is_null(), "get_double: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_date(&self, index: usize) -> i32 {
+            let addr = self.get_element_offset(index, 4) as *const i32;
+            debug_assert!(!addr.is_null(), "get_date: null pointer at index 
{index}");
+            // SAFETY: addr points to valid element data (4 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+
+        #[inline]
+        fn get_timestamp(&self, index: usize) -> i64 {
+            let addr = self.get_element_offset(index, 8) as *const i64;
+            debug_assert!(
+                !addr.is_null(),
+                "get_timestamp: null pointer at index {index}"
+            );
+            // SAFETY: addr points to valid element data (8 bytes) within the 
row/array region.
+            unsafe { addr.$read_method() }
+        }
+    };
+}
+pub(crate) use impl_primitive_accessors;
diff --git a/native/shuffle/src/writers/buf_batch_writer.rs 
b/native/shuffle/src/writers/buf_batch_writer.rs
index 6344a8e5f..cfddb4653 100644
--- a/native/shuffle/src/writers/buf_batch_writer.rs
+++ b/native/shuffle/src/writers/buf_batch_writer.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::ShuffleBlockWriter;
+use super::ShuffleBlockWriter;
 use arrow::array::RecordBatch;
 use arrow::compute::kernels::coalesce::BatchCoalescer;
 use datafusion::physical_plan::metrics::Time;
diff --git a/native/shuffle/src/writers/checksum.rs b/native/shuffle/src/writers/checksum.rs
new file mode 100644
index 000000000..b240302e6
--- /dev/null
+++ b/native/shuffle/src/writers/checksum.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use bytes::Buf;
+use crc32fast::Hasher;
+use datafusion_comet_jni_bridge::errors::{CometError, CometResult};
+use simd_adler32::Adler32;
+use std::io::{Cursor, SeekFrom};
+
+/// Checksum algorithms for writing IPC bytes.
+#[derive(Clone)]
+pub(crate) enum Checksum {
+    /// CRC32 checksum algorithm.
+    CRC32(Hasher),
+    /// Adler32 checksum algorithm.
+    Adler32(Adler32),
+}
+
+impl Checksum {
+    pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) -> CometResult<Self> {
+        match algo {
+            0 => {
+                let hasher = if let Some(initial) = initial_opt {
+                    Hasher::new_with_initial(initial)
+                } else {
+                    Hasher::new()
+                };
+                Ok(Checksum::CRC32(hasher))
+            }
+            1 => {
+                let hasher = if let Some(initial) = initial_opt {
+                    // Note that Adler32 initial state is not zero.
+                    // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`.
+                    Adler32::from_checksum(initial)
+                } else {
+                    Adler32::new()
+                };
+                Ok(Checksum::Adler32(hasher))
+            }
+            _ => Err(CometError::Internal(
+                "Unsupported checksum algorithm".to_string(),
+            )),
+        }
+    }
+
+    pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) -> CometResult<()> {
+        match self {
+            Checksum::CRC32(hasher) => {
+                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+                hasher.update(cursor.chunk());
+                Ok(())
+            }
+            Checksum::Adler32(hasher) => {
+                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+                hasher.write(cursor.chunk());
+                Ok(())
+            }
+        }
+    }
+
+    pub(crate) fn finalize(self) -> u32 {
+        match self {
+            Checksum::CRC32(hasher) => hasher.finalize(),
+            Checksum::Adler32(hasher) => hasher.finish(),
+        }
+    }
+}
diff --git a/native/shuffle/src/writers/mod.rs b/native/shuffle/src/writers/mod.rs
index b58989e46..75caf9f3a 100644
--- a/native/shuffle/src/writers/mod.rs
+++ b/native/shuffle/src/writers/mod.rs
@@ -16,7 +16,11 @@
 // under the License.
 
 mod buf_batch_writer;
-mod partition_writer;
+mod checksum;
+mod shuffle_block_writer;
+mod spill;
 
 pub(crate) use buf_batch_writer::BufBatchWriter;
-pub(crate) use partition_writer::PartitionWriter;
+pub(crate) use checksum::Checksum;
+pub use shuffle_block_writer::{CompressionCodec, ShuffleBlockWriter};
+pub(crate) use spill::PartitionWriter;
diff --git a/native/shuffle/src/codec.rs b/native/shuffle/src/writers/shuffle_block_writer.rs
similarity index 60%
rename from native/shuffle/src/codec.rs
rename to native/shuffle/src/writers/shuffle_block_writer.rs
index c8edc2468..5ed5330e3 100644
--- a/native/shuffle/src/codec.rs
+++ b/native/shuffle/src/writers/shuffle_block_writer.rs
@@ -17,17 +17,13 @@
 
 use arrow::array::RecordBatch;
 use arrow::datatypes::Schema;
-use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::StreamWriter;
-use bytes::Buf;
-use crc32fast::Hasher;
 use datafusion::common::DataFusionError;
 use datafusion::error::Result;
 use datafusion::physical_plan::metrics::Time;
-use datafusion_comet_jni_bridge::errors::{CometError, CometResult};
-use simd_adler32::Adler32;
 use std::io::{Cursor, Seek, SeekFrom, Write};
 
+/// Compression algorithm applied to shuffle IPC blocks.
 #[derive(Debug, Clone)]
 pub enum CompressionCodec {
     None,
@@ -36,6 +32,7 @@ pub enum CompressionCodec {
     Snappy,
 }
 
+/// Writes a record batch as a length-prefixed, compressed Arrow IPC block.
 #[derive(Clone)]
 pub struct ShuffleBlockWriter {
     codec: CompressionCodec,
@@ -147,93 +144,3 @@ impl ShuffleBlockWriter {
         Ok((end_pos - start_pos) as usize)
     }
 }
-
-pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
-    match &bytes[0..4] {
-        b"SNAP" => {
-            let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"LZ4_" => {
-            let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"ZSTD" => {
-            let decoder = zstd::Decoder::new(&bytes[4..])?;
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"NONE" => {
-            let mut reader =
-                unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        other => Err(DataFusionError::Execution(format!(
-            "Failed to decode batch: invalid compression codec: {other:?}"
-        ))),
-    }
-}
-
-/// Checksum algorithms for writing IPC bytes.
-#[derive(Clone)]
-pub(crate) enum Checksum {
-    /// CRC32 checksum algorithm.
-    CRC32(Hasher),
-    /// Adler32 checksum algorithm.
-    Adler32(Adler32),
-}
-
-impl Checksum {
-    pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) -> CometResult<Self> {
-        match algo {
-            0 => {
-                let hasher = if let Some(initial) = initial_opt {
-                    Hasher::new_with_initial(initial)
-                } else {
-                    Hasher::new()
-                };
-                Ok(Checksum::CRC32(hasher))
-            }
-            1 => {
-                let hasher = if let Some(initial) = initial_opt {
-                    // Note that Adler32 initial state is not zero.
-                    // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`.
-                    Adler32::from_checksum(initial)
-                } else {
-                    Adler32::new()
-                };
-                Ok(Checksum::Adler32(hasher))
-            }
-            _ => Err(CometError::Internal(
-                "Unsupported checksum algorithm".to_string(),
-            )),
-        }
-    }
-
-    pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) -> CometResult<()> {
-        match self {
-            Checksum::CRC32(hasher) => {
-                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
-                hasher.update(cursor.chunk());
-                Ok(())
-            }
-            Checksum::Adler32(hasher) => {
-                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
-                hasher.write(cursor.chunk());
-                Ok(())
-            }
-        }
-    }
-
-    pub(crate) fn finalize(self) -> u32 {
-        match self {
-            Checksum::CRC32(hasher) => hasher.finalize(),
-            Checksum::Adler32(hasher) => hasher.finish(),
-        }
-    }
-}
diff --git a/native/shuffle/src/writers/partition_writer.rs b/native/shuffle/src/writers/spill.rs
similarity index 95%
rename from native/shuffle/src/writers/partition_writer.rs
rename to native/shuffle/src/writers/spill.rs
index 48017871d..c16caddbf 100644
--- a/native/shuffle/src/writers/partition_writer.rs
+++ b/native/shuffle/src/writers/spill.rs
@@ -15,20 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use super::ShuffleBlockWriter;
 use crate::metrics::ShufflePartitionerMetrics;
 use crate::partitioners::PartitionedBatchIterator;
 use crate::writers::buf_batch_writer::BufBatchWriter;
-use crate::ShuffleBlockWriter;
 use datafusion::common::DataFusionError;
 use datafusion::execution::disk_manager::RefCountedTempFile;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use std::fs::{File, OpenOptions};
 
+/// A temporary disk file for spilling a partition's intermediate shuffle data.
 struct SpillFile {
     temp_file: RefCountedTempFile,
     file: File,
 }
 
+/// Manages encoding and optional disk spilling for a single shuffle partition.
 pub(crate) struct PartitionWriter {
     /// Spill file for intermediate shuffle output for this partition. Each spill event
     /// will append to this file and the contents will be copied to the shuffle file at


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to