tustvold commented on code in PR #1190: URL: https://github.com/apache/datafusion-comet/pull/1190#discussion_r1911955126
########## native/core/src/execution/shuffle/codec.rs: ########## @@ -0,0 +1,694 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::cast::AsArray; +use arrow_array::types::Int32Type; +use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch, + RecordBatchOptions, StringArray, TimestampMicrosecondArray, +}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use datafusion_common::DataFusionError; +use std::io::Write; +use std::sync::Arc; + +pub fn fast_codec_supports_type(data_type: &DataType) -> bool { + match data_type { + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Utf8 + | DataType::Binary => true, + DataType::Decimal128(_, s) if *s >= 0 => true, + DataType::Dictionary(k, v) if **k == DataType::Int32 => fast_codec_supports_type(v), + _ => false, + } +} + +enum DataTypeId { + Boolean = 
0, + Int8, + Int16, + Int32, + Int64, + Date32, + Timestamp, + TimestampNtz, + Decimal128, + Float32, + Float64, + Utf8, + Binary, + Dictionary, +} + +pub struct BatchWriter<W: Write> { + inner: W, +} + +impl<W: Write> BatchWriter<W> { + pub fn new(inner: W) -> Self { + Self { inner } + } + + /// Encode the schema (just column names because data types can vary per batch) + pub fn write_partial_schema(&mut self, schema: &Schema) -> Result<(), DataFusionError> { + let schema_len = schema.fields().len(); + self.inner.write_all(&schema_len.to_le_bytes())?; + for field in schema.fields() { + // field name + let field_name = field.name(); + self.inner.write_all(&field_name.len().to_le_bytes())?; + self.inner.write_all(field_name.as_str().as_bytes())?; + // TODO nullable - assume all nullable for now + } + Ok(()) + } + + fn write_data_type(&mut self, data_type: &DataType) -> Result<(), DataFusionError> { + match data_type { + DataType::Boolean => { + self.inner.write_all(&[DataTypeId::Boolean as u8])?; + } + DataType::Int8 => { + self.inner.write_all(&[DataTypeId::Int8 as u8])?; + } + DataType::Int16 => { + self.inner.write_all(&[DataTypeId::Int16 as u8])?; + } + DataType::Int32 => { + self.inner.write_all(&[DataTypeId::Int32 as u8])?; + } + DataType::Int64 => { + self.inner.write_all(&[DataTypeId::Int64 as u8])?; + } + DataType::Float32 => { + self.inner.write_all(&[DataTypeId::Float32 as u8])?; + } + DataType::Float64 => { + self.inner.write_all(&[DataTypeId::Float64 as u8])?; + } + DataType::Date32 => { + self.inner.write_all(&[DataTypeId::Date32 as u8])?; + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + self.inner.write_all(&[DataTypeId::TimestampNtz as u8])?; + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + self.inner.write_all(&[DataTypeId::Timestamp as u8])?; + let tz_bytes = tz.as_bytes(); + self.inner.write_all(&tz_bytes.len().to_le_bytes())?; + self.inner.write_all(tz_bytes)?; + } + DataType::Utf8 => { + 
self.inner.write_all(&[DataTypeId::Utf8 as u8])?; + } + DataType::Binary => { + self.inner.write_all(&[DataTypeId::Binary as u8])?; + } + DataType::Decimal128(p, s) if *s >= 0 => { + self.inner + .write_all(&[DataTypeId::Decimal128 as u8, *p, *s as u8])?; + } + DataType::Dictionary(k, v) => { + self.inner.write_all(&[DataTypeId::Dictionary as u8])?; + self.write_data_type(k)?; + self.write_data_type(v)?; + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type in fast writer {other}" + ))) + } + } + Ok(()) + } + + pub fn write_all(&mut self, bytes: &[u8]) -> std::io::Result<()> { + self.inner.write_all(bytes) + } + + pub fn write_batch(&mut self, batch: &RecordBatch) -> Result<(), DataFusionError> { + self.write_all(&batch.num_rows().to_le_bytes())?; + for i in 0..batch.num_columns() { + self.write_array(batch.column(i))?; + } + Ok(()) + } + + fn write_array(&mut self, col: &dyn Array) -> Result<(), DataFusionError> { + // data type + self.write_data_type(col.data_type())?; + // array contents + match col.data_type() { + DataType::Boolean => { + let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap(); + // boolean array is the only type we write the array length because it cannot + // be determined from the data buffer size (length is in bits rather than bytes) + self.write_all(&arr.len().to_le_bytes())?; + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int8 => { + let arr = col.as_any().downcast_ref::<Int8Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int16 => { + let arr = col.as_any().downcast_ref::<Int16Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int32 => { + let arr = 
col.as_any().downcast_ref::<Int32Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Int64 => { + let arr = col.as_any().downcast_ref::<Int64Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Float32 => { + let arr = col.as_any().downcast_ref::<Float32Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Float64 => { + let arr = col.as_any().downcast_ref::<Float64Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Date32 => { + let arr = col.as_any().downcast_ref::<Date32Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let arr = col + .as_any() + .downcast_ref::<TimestampMicrosecondArray>() + .unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Decimal128(_, _) => { + let arr = col.as_any().downcast_ref::<Decimal128Array>().unwrap(); + // write data buffer + self.write_buffer(arr.values().inner())?; + // write null buffer + self.write_null_buffer(arr.nulls())?; + } + DataType::Utf8 => { + let arr = col.as_any().downcast_ref::<StringArray>().unwrap(); + // write data buffer + self.write_buffer(arr.values())?; Review Comment: FYI: if the array is sliced, this will write the full data buffer, because the slice is not "materialized" into the buffer. -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org