This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 576069a31 Add ArrayWriter indirection (#1764) (#2091)
576069a31 is described below

commit 576069a3111094b63d39c3c3973b7cebc90b5a94
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Jul 21 08:06:08 2022 -0400

    Add ArrayWriter indirection (#1764) (#2091)
---
 parquet/src/arrow/arrow_writer/mod.rs | 65 ++++++++++++++++++++++++-----------
 parquet/src/file/writer.rs            | 51 +++++++++++++++++++--------
 2 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 75bd6f6aa..53b094a9e 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -33,7 +33,7 @@ use super::schema::{
     decimal_length_from_precision,
 };
 
-use crate::column::writer::ColumnWriter;
+use crate::column::writer::{get_column_writer, ColumnWriter};
 use crate::errors::{ParquetError, Result};
 use crate::file::metadata::RowGroupMetaDataPtr;
 use crate::file::properties::WriterProperties;
@@ -43,6 +43,44 @@ use levels::{calculate_array_levels, LevelInfo};
 
 mod levels;
 
+/// An object-safe API for writing an [`ArrayRef`]
+trait ArrayWriter {
+    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>;
+
+    fn close(&mut self) -> Result<()>;
+}
+
+/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`]
+struct ColumnArrayWriter<'a>(Option<SerializedColumnWriter<'a>>);
+
+impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
+    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
+        write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?;
+        Ok(())
+    }
+
+    fn close(&mut self) -> Result<()> {
+        self.0.take().unwrap().close()
+    }
+}
+
+fn get_writer<'a, W: Write>(
+    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
+) -> Result<Box<dyn ArrayWriter + 'a>> {
+    let array_writer = row_group_writer
+        .next_column_with_factory(|descr, props, page_writer, on_close| {
+            // TODO: Special case array readers (#1764)
+
+            let column_writer = get_column_writer(descr, props.clone(), page_writer);
+            let serialized_writer =
+                SerializedColumnWriter::new(column_writer, Some(on_close));
+
+            Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
+        })?
+        .expect("Unable to get column writer");
+    Ok(array_writer)
+}
+
 /// Arrow writer
 ///
 /// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
@@ -229,17 +267,6 @@ impl<W: Write> ArrowWriter<W> {
     }
 }
 
-/// Convenience method to get the next ColumnWriter from the RowGroupWriter
-#[inline]
-fn get_col_writer<'a, W: Write>(
-    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
-) -> Result<SerializedColumnWriter<'a>> {
-    let col_writer = row_group_writer
-        .next_column()?
-        .expect("Unable to get column writer");
-    Ok(col_writer)
-}
-
 fn write_leaves<W: Write>(
     row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
     arrays: &[ArrayRef],
@@ -277,15 +304,14 @@ fn write_leaves<W: Write>(
         | ArrowDataType::LargeUtf8
         | ArrowDataType::Decimal(_, _)
         | ArrowDataType::FixedSizeBinary(_) => {
-            let mut col_writer = get_col_writer(row_group_writer)?;
+            let mut writer = get_writer(row_group_writer)?;
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
-                write_leaf(
-                    col_writer.untyped(),
+                writer.write(
                     array,
                     levels.pop().expect("Levels exhausted"),
                 )?;
             }
-            col_writer.close()?;
+            writer.close()?;
             Ok(())
         }
         ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
@@ -338,17 +364,16 @@ fn write_leaves<W: Write>(
             Ok(())
         }
         ArrowDataType::Dictionary(_, value_type) => {
-            let mut col_writer = get_col_writer(row_group_writer)?;
+            let mut writer = get_writer(row_group_writer)?;
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
                 // cast dictionary to a primitive
                 let array = arrow::compute::cast(array, value_type)?;
-                write_leaf(
-                    col_writer.untyped(),
+                writer.write(
                     &array,
                     levels.pop().expect("Levels exhausted"),
                 )?;
             }
-            col_writer.close()?;
+            writer.close()?;
             Ok(())
         }
         ArrowDataType::Float16 => Err(ParquetError::ArrowError(
diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs
index 10983c741..467273aaa 100644
--- a/parquet/src/file/writer.rs
+++ b/parquet/src/file/writer.rs
@@ -37,7 +37,9 @@ use crate::file::{
     metadata::*, properties::WriterPropertiesPtr,
     statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC,
 };
-use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr};
+use crate::schema::types::{
+    self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr,
+};
 use crate::util::io::TryClone;
 
 /// A wrapper around a [`Write`] that keeps track of the number
@@ -367,22 +369,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
         }
     }
 
-    /// Returns the next column writer, if available; otherwise returns `None`.
-    /// In case of any IO error or Thrift error, or if row group writer has already been
-    /// closed returns `Err`.
-    pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
+    /// Returns the next column writer, if available, using the factory function;
+    /// otherwise returns `None`.
+    pub(crate) fn next_column_with_factory<'b, F, C>(
+        &'b mut self,
+        factory: F,
+    ) -> Result<Option<C>>
+    where
+        F: FnOnce(
+            ColumnDescPtr,
+            &'b WriterPropertiesPtr,
+            Box<dyn PageWriter + 'b>,
+            OnCloseColumnChunk<'b>,
+        ) -> Result<C>,
+    {
         self.assert_previous_writer_closed()?;
 
         if self.column_index >= self.descr.num_columns() {
             return Ok(None);
         }
         let page_writer = Box::new(SerializedPageWriter::new(self.buf));
-        let column_writer = get_column_writer(
-            self.descr.column(self.column_index),
-            self.props.clone(),
-            page_writer,
-        );
-        self.column_index += 1;
 
         let total_bytes_written = &mut self.total_bytes_written;
         let total_rows_written = &mut self.total_rows_written;
@@ -413,10 +419,25 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
                 Ok(())
             };
 
-        Ok(Some(SerializedColumnWriter::new(
-            column_writer,
-            Some(Box::new(on_close)),
-        )))
+        let column = self.descr.column(self.column_index);
+        self.column_index += 1;
+
+        Ok(Some(factory(
+            column,
+            &self.props,
+            page_writer,
+            Box::new(on_close),
+        )?))
+    }
+
+    /// Returns the next column writer, if available; otherwise returns `None`.
+    /// In case of any IO error or Thrift error, or if row group writer has already been
+    /// closed returns `Err`.
+    pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
+        self.next_column_with_factory(|descr, props, page_writer, on_close| {
+            let column_writer = get_column_writer(descr, props.clone(), page_writer);
+            Ok(SerializedColumnWriter::new(column_writer, Some(on_close)))
+        })
     }
 
     /// Closes this row group writer and returns row group metadata.

Reply via email to