This is an automated email from the ASF dual-hosted git repository. tustvold pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push: new 8fff5e4a07 Add ListBuilder::with_field to support non nullable list fields (#5330) (#5331) 8fff5e4a07 is described below commit 8fff5e4a075c20619690c967e217163cd74e0656 Author: Raphael Taylor-Davies <1781103+tustv...@users.noreply.github.com> AuthorDate: Thu Jan 25 12:13:43 2024 +0000 Add ListBuilder::with_field to support non nullable list fields (#5330) (#5331) * Add ListBuilder::with_field (#5330) * Tweak docs * Review feedback --- arrow-array/src/builder/generic_list_builder.rs | 111 +++++++++++++++--------- 1 file changed, 72 insertions(+), 39 deletions(-) diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 116e2553cf..b857224c5d 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -17,10 +17,9 @@ use crate::builder::{ArrayBuilder, BufferBuilder}; use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; -use arrow_buffer::Buffer; use arrow_buffer::NullBufferBuilder; -use arrow_data::ArrayData; -use arrow_schema::Field; +use arrow_buffer::{Buffer, OffsetBuffer}; +use arrow_schema::{Field, FieldRef}; use std::any::Any; use std::sync::Arc; @@ -92,6 +91,7 @@ pub struct GenericListBuilder<OffsetSize: OffsetSizeTrait, T: ArrayBuilder> { offsets_builder: BufferBuilder<OffsetSize>, null_buffer_builder: NullBufferBuilder, values_builder: T, + field: Option<FieldRef>, } impl<O: OffsetSizeTrait, T: ArrayBuilder + Default> Default for GenericListBuilder<O, T> { @@ -116,6 +116,20 @@ impl<OffsetSize: OffsetSizeTrait, T: ArrayBuilder> GenericListBuilder<OffsetSize offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), values_builder, + field: None, + } + } + + /// Override the field passed to [`GenericListArray::new`] + /// + /// By default a nullable field is created with the name `item` + /// + /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the + /// field's data type does not match that of `T` + pub fn with_field(self, field: impl Into<FieldRef>) -> Self { + Self { + field: Some(field.into()), + ..self } } } @@ -275,53 +289,37 @@ where /// Builds the [`GenericListArray`] and reset this builder. pub fn finish(&mut self) -> GenericListArray<OffsetSize> { - let len = self.len(); - let values_arr = self.values_builder.finish(); - let values_data = values_arr.to_data(); + let values = self.values_builder.finish(); + let nulls = self.null_buffer_builder.finish(); - let offset_buffer = self.offsets_builder.finish(); - let null_bit_buffer = self.null_buffer_builder.finish(); + let offsets = self.offsets_builder.finish(); + // Safety: Safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; self.offsets_builder.append(OffsetSize::zero()); - let field = Arc::new(Field::new( - "item", - values_data.data_type().clone(), - true, // TODO: find a consistent way of getting this - )); - let data_type = GenericListArray::<OffsetSize>::DATA_TYPE_CONSTRUCTOR(field); - let array_data_builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(values_data) - .nulls(null_bit_buffer); - let array_data = unsafe { array_data_builder.build_unchecked() }; + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; - GenericListArray::<OffsetSize>::from(array_data) + GenericListArray::new(field, offsets, values, nulls) } /// Builds the [`GenericListArray`] without resetting the builder. pub fn finish_cloned(&self) -> GenericListArray<OffsetSize> { - let len = self.len(); - let values_arr = self.values_builder.finish_cloned(); - let values_data = values_arr.to_data(); - - let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let values = self.values_builder.finish_cloned(); let nulls = self.null_buffer_builder.finish_cloned(); - let field = Arc::new(Field::new( - "item", - values_data.data_type().clone(), - true, // TODO: find a consistent way of getting this - )); - let data_type = GenericListArray::<OffsetSize>::DATA_TYPE_CONSTRUCTOR(field); - let array_data_builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(values_data) - .nulls(nulls); - let array_data = unsafe { array_data_builder.build_unchecked() }; + let offsets = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + // Safety: safe by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + let field = match &self.field { + Some(f) => f.clone(), + None => Arc::new(Field::new("item", values.data_type().clone(), true)), + }; - GenericListArray::<OffsetSize>::from(array_data) + GenericListArray::new(field, offsets, values, nulls) } /// Returns the current offsets buffer as a slice @@ -765,4 +763,39 @@ mod tests { assert_eq!(0, i1.null_count()); assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); } + + #[test] + fn test_with_field() { + let field = Arc::new(Field::new("bar", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), Some(2), Some(3)]); + builder.append_null(); // This is fine as nullability refers to nullability of values + builder.append_value([Some(4)]); + let array = builder.finish(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::List(field.clone())); + + builder.append_value([Some(4), Some(5)]); + let array = builder.finish(); + assert_eq!(array.data_type(), &DataType::List(field)); + assert_eq!(array.len(), 1); + } + + #[test] + #[should_panic(expected = "Non-nullable field of ListArray \\\"item\\\" cannot contain nulls")] + fn test_checks_nullability() { + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1), None]); + builder.finish(); + } + + #[test] + #[should_panic(expected = "ListArray expected data type Int64 got Int32")] + fn test_checks_data_type() { + let field = Arc::new(Field::new("item", DataType::Int64, false)); + let mut builder = ListBuilder::new(Int32Builder::new()).with_field(field.clone()); + builder.append_value([Some(1)]); + builder.finish(); + } }