comphead commented on code in PR #3289:
URL: https://github.com/apache/datafusion-comet/pull/3289#discussion_r2855082842
##########
native/core/src/execution/shuffle/spark_unsafe/row.rs:
##########
@@ -441,7 +492,662 @@ pub(super) fn append_field(
Ok(())
}
/// Appends nested struct fields to the struct builder using field-major order.
/// This is a helper function for processing nested struct fields recursively.
///
/// Unlike `append_struct_fields_field_major`, this function takes slices of row addresses,
/// sizes, and null flags directly, without needing to navigate from a parent row.
///
/// `row_addresses`, `row_sizes`, and `struct_is_null` are parallel arrays with one entry
/// per row. When `struct_is_null[i]` is true, the corresponding address/size entries are
/// placeholder zeros and row memory is never dereferenced for that row.
#[allow(clippy::redundant_closure_call)]
fn append_nested_struct_fields_field_major(
    row_addresses: &[jlong],
    row_sizes: &[jint],
    struct_is_null: &[bool],
    struct_builder: &mut StructBuilder,
    fields: &arrow::datatypes::Fields,
) -> Result<(), CometError> {
    let num_rows = row_addresses.len();
    // Single reusable row cursor; `point_to` re-targets it at each row's memory.
    let mut row = SparkUnsafeRow::new_with_num_fields(fields.len());

    // Helper macro for processing primitive fields: fetch the typed field builder
    // once per field, then append one value (or null) per row.
    macro_rules! process_field {
        ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
            let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx);

            for row_idx in 0..num_rows {
                if struct_is_null[row_idx] {
                    // Struct is null, field is also null
                    field_builder.append_null();
                } else {
                    let row_addr = row_addresses[row_idx];
                    let row_size = row_sizes[row_idx];
                    row.point_to(row_addr, row_size);

                    if row.is_null_at($field_idx) {
                        field_builder.append_null();
                    } else {
                        field_builder.append_value($get_value(&row, $field_idx));
                    }
                }
            }
        }};
    }

    // Process each field across all rows (field-major: one type dispatch per field
    // instead of one per row-field pair).
    for (field_idx, field) in fields.iter().enumerate() {
        match field.data_type() {
            DataType::Boolean => {
                process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_boolean(idx));
            }
            DataType::Int8 => {
                process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_byte(idx));
            }
            DataType::Int16 => {
                process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_short(idx));
            }
            DataType::Int32 => {
                process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_int(idx));
            }
            DataType::Int64 => {
                process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_long(idx));
            }
            DataType::Float32 => {
                process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_float(idx));
            }
            DataType::Float64 => {
                process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_double(idx));
            }
            DataType::Date32 => {
                process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_date(idx));
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                process_field!(
                    TimestampMicrosecondBuilder,
                    field_idx,
                    |row: &SparkUnsafeRow, idx| row.get_timestamp(idx)
                );
            }
            // Variable-length types repeat the macro's null-handling pattern inline
            // because their accessors return borrowed/owned buffers rather than
            // simple Copy values.
            DataType::Binary => {
                let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx);

                for row_idx in 0..num_rows {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        let row_addr = row_addresses[row_idx];
                        let row_size = row_sizes[row_idx];
                        row.point_to(row_addr, row_size);

                        if row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(row.get_binary(field_idx));
                        }
                    }
                }
            }
            DataType::Utf8 => {
                let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx);

                for row_idx in 0..num_rows {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        let row_addr = row_addresses[row_idx];
                        let row_size = row_sizes[row_idx];
                        row.point_to(row_addr, row_size);

                        if row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(row.get_string(field_idx));
                        }
                    }
                }
            }
            DataType::Decimal128(p, _) => {
                // Copy the precision out so the borrow of `field` ends before we
                // mutably borrow the builder.
                let p = *p;
                let field_builder =
                    get_field_builder!(struct_builder, Decimal128Builder, field_idx);

                for row_idx in 0..num_rows {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        let row_addr = row_addresses[row_idx];
                        let row_size = row_sizes[row_idx];
                        row.point_to(row_addr, row_size);

                        if row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(row.get_decimal(field_idx, p));
                        }
                    }
                }
            }
            DataType::Struct(nested_fields) => {
                let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx);

                // Collect nested struct addresses and sizes in one pass, building validity
                let mut nested_addresses: Vec<jlong> = Vec::with_capacity(num_rows);
                let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
                let mut nested_is_null: Vec<bool> = Vec::with_capacity(num_rows);

                for row_idx in 0..num_rows {
                    if struct_is_null[row_idx] {
                        // Parent struct is null, nested struct is also null
                        nested_builder.append_null();
                        nested_is_null.push(true);
                        nested_addresses.push(0);
                        nested_sizes.push(0);
                    } else {
                        let row_addr = row_addresses[row_idx];
                        let row_size = row_sizes[row_idx];
                        row.point_to(row_addr, row_size);

                        if row.is_null_at(field_idx) {
                            nested_builder.append_null();
                            nested_is_null.push(true);
                            nested_addresses.push(0);
                            nested_sizes.push(0);
                        } else {
                            nested_builder.append(true);
                            nested_is_null.push(false);
                            // Get nested struct address and size
                            let nested_row = row.get_struct(field_idx, nested_fields.len());
                            nested_addresses.push(nested_row.get_row_addr());
                            nested_sizes.push(nested_row.get_row_size());
                        }
                    }
                }

                // Recursively process nested struct fields in field-major order
                append_nested_struct_fields_field_major(
                    &nested_addresses,
                    &nested_sizes,
                    &nested_is_null,
                    nested_builder,
                    nested_fields,
                )?;
            }
            // For list and map, fall back to append_field since they have variable-length elements
            dt @ (DataType::List(_) | DataType::Map(_, _)) => {
                for row_idx in 0..num_rows {
                    if struct_is_null[row_idx] {
                        // NOTE(review): relies on SparkUnsafeRow::default() reporting
                        // every field as null so append_field emits a null entry
                        // without touching row memory — confirm that invariant.
                        let null_row = SparkUnsafeRow::default();
                        append_field(dt, struct_builder, &null_row, field_idx)?;
                    } else {
                        let row_addr = row_addresses[row_idx];
                        let row_size = row_sizes[row_idx];
                        row.point_to(row_addr, row_size);
                        append_field(dt, struct_builder, &row, field_idx)?;
                    }
                }
            }
            _ => {
                unreachable!(
                    "Unsupported data type of struct field: {:?}",
                    field.data_type()
                )
            }
        }
    }

    Ok(())
}
+
/// Reads row address and size from JVM-provided pointer arrays and points the row to that data.
///
/// Implemented as a macro (rather than a function) so it can re-target any
/// `SparkUnsafeRow` expression the caller already holds mutably.
///
/// # Safety
/// Caller must ensure row_addresses_ptr and row_sizes_ptr are valid for index i.
/// This is guaranteed when called from append_columns with indices in [row_start, row_end).
macro_rules! read_row_at {
    ($row:expr, $row_addresses_ptr:expr, $row_sizes_ptr:expr, $i:expr) => {{
        // SAFETY: Caller guarantees pointers are valid for this index (see macro doc)
        let row_addr = unsafe { *$row_addresses_ptr.add($i) };
        let row_size = unsafe { *$row_sizes_ptr.add($i) };
        $row.point_to(row_addr, row_size);
    }};
}
+
/// Appends a batch of list values to the list builder with a single type dispatch.
/// This moves type dispatch from O(rows) to O(1), significantly improving performance
/// for large batches.
///
/// The caller must guarantee `row_addresses_ptr` / `row_sizes_ptr` are valid for
/// every index in `[row_start, row_end)` — see `read_row_at!`'s safety contract.
#[allow(clippy::too_many_arguments)]
fn append_list_column_batch(
    row_addresses_ptr: *mut jlong,
    row_sizes_ptr: *mut jint,
    row_start: usize,
    row_end: usize,
    schema: &[DataType],
    column_idx: usize,
    element_type: &DataType,
    list_builder: &mut ListBuilder<Box<dyn ArrayBuilder>>,
) -> Result<(), CometError> {
    let mut row = SparkUnsafeRow::new(schema);

    // Helper macro for primitive element types - gets builder fresh each iteration
    // to avoid borrow conflicts with list_builder.append()
    macro_rules! process_primitive_lists {
        ($builder_type:ty, $append_fn:ident) => {{
            for i in row_start..row_end {
                read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);

                if row.is_null_at(column_idx) {
                    list_builder.append_null();
                } else {
                    let array = row.get_array(column_idx);
                    // Get values builder fresh each iteration to avoid borrow conflict
                    let values_builder = list_builder
                        .values()
                        .as_any_mut()
                        .downcast_mut::<$builder_type>()
                        .expect(stringify!($builder_type));
                    // NOTE(review): the bool const parameter presumably selects
                    // null-aware element copying — confirm against the
                    // append_*_to_builder definitions.
                    array.$append_fn::<true>(values_builder);
                    list_builder.append(true);
                }
            }
        }};
    }

    // One dispatch on the element type, then a tight per-row loop.
    match element_type {
        DataType::Boolean => {
            process_primitive_lists!(BooleanBuilder, append_booleans_to_builder);
        }
        DataType::Int8 => {
            process_primitive_lists!(Int8Builder, append_bytes_to_builder);
        }
        DataType::Int16 => {
            process_primitive_lists!(Int16Builder, append_shorts_to_builder);
        }
        DataType::Int32 => {
            process_primitive_lists!(Int32Builder, append_ints_to_builder);
        }
        DataType::Int64 => {
            process_primitive_lists!(Int64Builder, append_longs_to_builder);
        }
        DataType::Float32 => {
            process_primitive_lists!(Float32Builder, append_floats_to_builder);
        }
        DataType::Float64 => {
            process_primitive_lists!(Float64Builder, append_doubles_to_builder);
        }
        DataType::Date32 => {
            process_primitive_lists!(Date32Builder, append_dates_to_builder);
        }
        DataType::Timestamp(TimeUnit::Microsecond, _) => {
            process_primitive_lists!(TimestampMicrosecondBuilder, append_timestamps_to_builder);
        }
        // For complex element types, fall back to per-row dispatch
        _ => {
            for i in row_start..row_end {
                read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);

                if row.is_null_at(column_idx) {
                    list_builder.append_null();
                } else {
                    append_list_element(element_type, list_builder, &row.get_array(column_idx))?;
                }
            }
        }
    }

    Ok(())
}
+
/// Appends a batch of map values to the map builder with a single type dispatch.
/// This moves type dispatch from O(rows × 2) to O(2), improving performance for maps.
///
/// Only a few common primitive key/value combinations are special-cased; everything
/// else falls back to the generic per-row `append_map_elements` path.
///
/// The caller must guarantee `row_addresses_ptr` / `row_sizes_ptr` are valid for
/// every index in `[row_start, row_end)` — see `read_row_at!`'s safety contract.
#[allow(clippy::too_many_arguments)]
fn append_map_column_batch(
    row_addresses_ptr: *mut jlong,
    row_sizes_ptr: *mut jint,
    row_start: usize,
    row_end: usize,
    schema: &[DataType],
    column_idx: usize,
    field: &arrow::datatypes::FieldRef,
    map_builder: &mut MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>,
) -> Result<(), CometError> {
    let mut row = SparkUnsafeRow::new(schema);
    let (key_field, value_field, _) = get_map_key_value_fields(field)?;
    let key_type = key_field.data_type();
    let value_type = value_field.data_type();

    // Helper macro for processing maps with primitive key/value types
    // Uses scoped borrows to avoid borrow checker conflicts
    macro_rules! process_primitive_maps {
        ($key_builder:ty, $key_append:ident, $val_builder:ty, $val_append:ident) => {{
            for i in row_start..row_end {
                read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);

                if row.is_null_at(column_idx) {
                    map_builder.append(false)?;
                } else {
                    let map = row.get_map(column_idx);
                    // Process keys in a scope so borrow ends
                    // NOTE(review): `::<false>` for keys vs `::<true>` for values
                    // presumably reflects that map keys are never null while values
                    // may be — confirm against the append_*_to_builder definitions.
                    {
                        let keys_builder = map_builder
                            .keys()
                            .as_any_mut()
                            .downcast_mut::<$key_builder>()
                            .expect(stringify!($key_builder));
                        map.keys.$key_append::<false>(keys_builder);
                    }
                    // Process values in a scope so borrow ends
                    {
                        let values_builder = map_builder
                            .values()
                            .as_any_mut()
                            .downcast_mut::<$val_builder>()
                            .expect(stringify!($val_builder));
                        map.values.$val_append::<true>(values_builder);
                    }
                    map_builder.append(true)?;
                }
            }
        }};
    }

    // Optimize common map type combinations
    match (key_type, value_type) {
        // Map<Int64, Int64>
        (DataType::Int64, DataType::Int64) => {
            process_primitive_maps!(
                Int64Builder,
                append_longs_to_builder,
                Int64Builder,
                append_longs_to_builder
            );
        }
        // Map<Int64, Float64>
        (DataType::Int64, DataType::Float64) => {
            process_primitive_maps!(
                Int64Builder,
                append_longs_to_builder,
                Float64Builder,
                append_doubles_to_builder
            );
        }
        // Map<Int32, Int32>
        (DataType::Int32, DataType::Int32) => {
            process_primitive_maps!(
                Int32Builder,
                append_ints_to_builder,
                Int32Builder,
                append_ints_to_builder
            );
        }
        // Map<Int32, Int64>
        (DataType::Int32, DataType::Int64) => {
            process_primitive_maps!(
                Int32Builder,
                append_ints_to_builder,
                Int64Builder,
                append_longs_to_builder
            );
        }
        // For other types, fall back to per-row dispatch
        _ => {
            for i in row_start..row_end {
                read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);

                if row.is_null_at(column_idx) {
                    map_builder.append(false)?;
                } else {
                    append_map_elements(field, map_builder, &row.get_map(column_idx))?;
                }
            }
        }
    }

    Ok(())
}
+
/// Appends struct fields to the struct builder using field-major order.
/// This processes one field at a time across all rows, which moves type dispatch
/// outside the row loop (O(fields) dispatches instead of O(rows × fields)).
///
/// First pass records the struct-level validity for every row; second pass walks
/// each field across all rows. Throughout, `i` indexes the JVM pointer arrays
/// (absolute, in `[row_start, row_end)`) while `row_idx` is the matching 0-based
/// index into the local `struct_is_null` vec.
///
/// The caller must guarantee `row_addresses_ptr` / `row_sizes_ptr` are valid for
/// every index in `[row_start, row_end)` — see `read_row_at!`'s safety contract.
#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)]
fn append_struct_fields_field_major(
    row_addresses_ptr: *mut jlong,
    row_sizes_ptr: *mut jint,
    row_start: usize,
    row_end: usize,
    parent_row: &mut SparkUnsafeRow,
    column_idx: usize,
    struct_builder: &mut StructBuilder,
    fields: &arrow::datatypes::Fields,
) -> Result<(), CometError> {
    let num_rows = row_end - row_start;
    let num_fields = fields.len();

    // First pass: Build struct validity and collect which structs are null
    // We use a Vec<bool> for simplicity; could use a bitset for better memory
    let mut struct_is_null = Vec::with_capacity(num_rows);

    for i in row_start..row_end {
        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);

        let is_null = parent_row.is_null_at(column_idx);
        struct_is_null.push(is_null);

        if is_null {
            struct_builder.append_null();
        } else {
            struct_builder.append(true);
        }
    }

    // Helper macro for processing primitive fields
    macro_rules! process_field {
        ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
            let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx);

            for (row_idx, i) in (row_start..row_end).enumerate() {
                if struct_is_null[row_idx] {
                    // Struct is null, field is also null
                    field_builder.append_null();
                } else {
                    read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                    let nested_row = parent_row.get_struct(column_idx, num_fields);

                    if nested_row.is_null_at($field_idx) {
                        field_builder.append_null();
                    } else {
                        field_builder.append_value($get_value(&nested_row, $field_idx));
                    }
                }
            }
        }};
    }

    // Second pass: Process each field across all rows
    for (field_idx, field) in fields.iter().enumerate() {
        match field.data_type() {
            DataType::Boolean => {
                process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_boolean(idx));
            }
            DataType::Int8 => {
                process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_byte(idx));
            }
            DataType::Int16 => {
                process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_short(idx));
            }
            DataType::Int32 => {
                process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_int(idx));
            }
            DataType::Int64 => {
                process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_long(idx));
            }
            DataType::Float32 => {
                process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_float(idx));
            }
            DataType::Float64 => {
                process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_double(idx));
            }
            DataType::Date32 => {
                process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row
                    .get_date(idx));
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                process_field!(
                    TimestampMicrosecondBuilder,
                    field_idx,
                    |row: &SparkUnsafeRow, idx| row.get_timestamp(idx)
                );
            }
            // Variable-length types repeat the macro's null-handling pattern inline.
            DataType::Binary => {
                let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx);

                for (row_idx, i) in (row_start..row_end).enumerate() {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                        let nested_row = parent_row.get_struct(column_idx, num_fields);

                        if nested_row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(nested_row.get_binary(field_idx));
                        }
                    }
                }
            }
            DataType::Utf8 => {
                let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx);

                for (row_idx, i) in (row_start..row_end).enumerate() {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                        let nested_row = parent_row.get_struct(column_idx, num_fields);

                        if nested_row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(nested_row.get_string(field_idx));
                        }
                    }
                }
            }
            DataType::Decimal128(p, _) => {
                // Copy the precision out so the borrow of `field` ends before we
                // mutably borrow the builder.
                let p = *p;
                let field_builder =
                    get_field_builder!(struct_builder, Decimal128Builder, field_idx);

                for (row_idx, i) in (row_start..row_end).enumerate() {
                    if struct_is_null[row_idx] {
                        field_builder.append_null();
                    } else {
                        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                        let nested_row = parent_row.get_struct(column_idx, num_fields);

                        if nested_row.is_null_at(field_idx) {
                            field_builder.append_null();
                        } else {
                            field_builder.append_value(nested_row.get_decimal(field_idx, p));
                        }
                    }
                }
            }
            // For nested structs, apply field-major processing recursively
            DataType::Struct(nested_fields) => {
                let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx);

                // Collect nested struct addresses and sizes in one pass, building validity
                let mut nested_addresses: Vec<jlong> = Vec::with_capacity(num_rows);
                let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
                let mut nested_is_null: Vec<bool> = Vec::with_capacity(num_rows);

                for (row_idx, i) in (row_start..row_end).enumerate() {
                    if struct_is_null[row_idx] {
                        // Parent struct is null, nested struct is also null
                        nested_builder.append_null();
                        nested_is_null.push(true);
                        nested_addresses.push(0);
                        nested_sizes.push(0);
                    } else {
                        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                        let parent_struct = parent_row.get_struct(column_idx, num_fields);

                        if parent_struct.is_null_at(field_idx) {
                            nested_builder.append_null();
                            nested_is_null.push(true);
                            nested_addresses.push(0);
                            nested_sizes.push(0);
                        } else {
                            nested_builder.append(true);
                            nested_is_null.push(false);
                            // Get nested struct address and size
                            let nested_row =
                                parent_struct.get_struct(field_idx, nested_fields.len());
                            nested_addresses.push(nested_row.get_row_addr());
                            nested_sizes.push(nested_row.get_row_size());
                        }
                    }
                }

                // Recursively process nested struct fields in field-major order
                append_nested_struct_fields_field_major(
                    &nested_addresses,
                    &nested_sizes,
                    &nested_is_null,
                    nested_builder,
                    nested_fields,
                )?;
            }
            // For list and map, fall back to append_field since they have variable-length elements
            dt @ (DataType::List(_) | DataType::Map(_, _)) => {
                for (row_idx, i) in (row_start..row_end).enumerate() {
                    if struct_is_null[row_idx] {
                        // NOTE(review): relies on SparkUnsafeRow::default() reporting
                        // every field as null so append_field emits a null entry —
                        // confirm that invariant.
                        let null_row = SparkUnsafeRow::default();
                        append_field(dt, struct_builder, &null_row, field_idx)?;
                    } else {
                        read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
                        let nested_row = parent_row.get_struct(column_idx, num_fields);
                        append_field(dt, struct_builder, &nested_row, field_idx)?;
                    }
                }
            }
            _ => {
                unreachable!(
                    "Unsupported data type of struct field: {:?}",
                    field.data_type()
                )
            }
        }
    }

    Ok(())
}
+
/// Appends column of top rows to the given array builder.
+///
+/// # Safety
+///
+/// The caller must ensure:
+/// - `row_addresses_ptr` points to an array of at least `row_end` jlong values
+/// - `row_sizes_ptr` points to an array of at least `row_end` jint values
+/// - Each address in `row_addresses_ptr[row_start..row_end]` points to valid
Spark UnsafeRow data
Review Comment:
Looks like https://github.com/apache/datafusion-comet/pull/3367 is directly
related to safety
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]