This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 34bbe78f8 feat: Cast numeric (non int) to timestamp (#3559)
34bbe78f8 is described below
commit 34bbe78f87b36e596523de45fac3c7a54fc72aae
Author: Bhargava Vadlamani <[email protected]>
AuthorDate: Fri Mar 6 13:52:56 2026 -0800
feat: Cast numeric (non int) to timestamp (#3559)
* float_to_timestamp
* non_numeric_to_timestamp
---
native/spark-expr/Cargo.toml | 4 +
.../benches/cast_non_int_numeric_timestamp.rs | 143 ++++++++++++++
native/spark-expr/src/conversion_funcs/boolean.rs | 45 ++++-
native/spark-expr/src/conversion_funcs/cast.rs | 16 +-
native/spark-expr/src/conversion_funcs/numeric.rs | 212 ++++++++++++++++++++-
.../org/apache/comet/expressions/CometCast.scala | 31 +--
.../scala/org/apache/comet/CometCastSuite.scala | 79 +++++---
.../scala/org/apache/spark/sql/CometTestBase.scala | 49 +++++
8 files changed, 531 insertions(+), 48 deletions(-)
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index e7c238f7e..b014c49a2 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
[[bench]]
name = "cast_from_boolean"
harness = false
+
+[[bench]]
+name = "cast_non_int_numeric_timestamp"
+harness = false
diff --git a/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
new file mode 100644
index 000000000..ea1a85e40
--- /dev/null
+++ b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::{BooleanBuilder, Decimal128Builder, Float32Builder,
Float64Builder};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC",
false);
+ let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond,
Some("UTC".into()));
+
+ let mut group = c.benchmark_group("cast_non_int_numeric_to_timestamp");
+
+ // Float32 -> Timestamp
+ let batch_f32 = create_float32_batch();
+ let expr_f32 = Arc::new(Column::new("a", 0));
+ let cast_f32_to_ts = Cast::new(expr_f32, timestamp_type.clone(),
spark_cast_options.clone());
+ group.bench_function("cast_f32_to_timestamp", |b| {
+ b.iter(|| cast_f32_to_ts.evaluate(&batch_f32).unwrap());
+ });
+
+ // Float64 -> Timestamp
+ let batch_f64 = create_float64_batch();
+ let expr_f64 = Arc::new(Column::new("a", 0));
+ let cast_f64_to_ts = Cast::new(expr_f64, timestamp_type.clone(),
spark_cast_options.clone());
+ group.bench_function("cast_f64_to_timestamp", |b| {
+ b.iter(|| cast_f64_to_ts.evaluate(&batch_f64).unwrap());
+ });
+
+ // Boolean -> Timestamp
+ let batch_bool = create_boolean_batch();
+ let expr_bool = Arc::new(Column::new("a", 0));
+ let cast_bool_to_ts = Cast::new(
+ expr_bool,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ );
+ group.bench_function("cast_bool_to_timestamp", |b| {
+ b.iter(|| cast_bool_to_ts.evaluate(&batch_bool).unwrap());
+ });
+
+ // Decimal128 -> Timestamp
+ let batch_decimal = create_decimal128_batch();
+ let expr_decimal = Arc::new(Column::new("a", 0));
+ let cast_decimal_to_ts = Cast::new(
+ expr_decimal,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ );
+ group.bench_function("cast_decimal_to_timestamp", |b| {
+ b.iter(|| cast_decimal_to_ts.evaluate(&batch_decimal).unwrap());
+ });
+
+ group.finish();
+}
+
+fn create_float32_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+ let mut b = Float32Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<f32>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_float64_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64,
true)]));
+ let mut b = Float64Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<f64>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_boolean_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean,
true)]));
+ let mut b = BooleanBuilder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<bool>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_decimal128_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "a",
+ DataType::Decimal128(18, 6),
+ true,
+ )]));
+ let mut b = Decimal128Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(i as i128 * 1_000_000);
+ }
+ }
+ let array = b.finish().with_precision_and_scale(18, 6).unwrap();
+ RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs
b/native/spark-expr/src/conversion_funcs/boolean.rs
index db288fa32..49855790b 100644
--- a/native/spark-expr/src/conversion_funcs/boolean.rs
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::SparkResult;
-use arrow::array::{ArrayRef, AsArray, Decimal128Array};
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array,
TimestampMicrosecondBuilder};
use arrow::datatypes::DataType;
use std::sync::Arc;
@@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type:
&DataType) -> bool {
)
}
-// only DF incompatible boolean cast
pub fn cast_boolean_to_decimal(
array: &ArrayRef,
precision: u8,
@@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal(
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
}
+pub(crate) fn cast_boolean_to_timestamp(
+ array_ref: &ArrayRef,
+ target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+ let bool_array = array_ref.as_boolean();
+ let mut builder =
TimestampMicrosecondBuilder::with_capacity(bool_array.len());
+
+ for i in 0..bool_array.len() {
+ if bool_array.is_null(i) {
+ builder.append_null();
+ } else {
+ let micros = if bool_array.value(i) { 1 } else { 0 };
+ builder.append_value(micros);
+ }
+ }
+
+ Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as
ArrayRef)
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -53,6 +71,7 @@ mod tests {
Int64Array, Int8Array, StringArray,
};
use arrow::datatypes::DataType::Decimal128;
+ use arrow::datatypes::TimestampMicrosecondType;
use std::sync::Arc;
fn test_input_bool_array() -> ArrayRef {
@@ -193,4 +212,26 @@ mod tests {
assert_eq!(arr.value(1), expected_arr.value(1));
assert!(arr.is_null(2));
}
+
+ #[test]
+ fn test_cast_boolean_to_timestamp() {
+ let timezones: [Option<Arc<str>>; 3] = [
+ Some(Arc::from("UTC")),
+ Some(Arc::from("America/Los_Angeles")),
+ None,
+ ];
+
+ for tz in &timezones {
+ let bool_array: ArrayRef =
+ Arc::new(BooleanArray::from(vec![Some(true), Some(false),
None]));
+
+ let result = cast_boolean_to_timestamp(&bool_array, tz).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 1); // true -> 1 microsecond
+ assert_eq!(ts_array.value(1), 0); // false -> 0 (epoch)
+ assert!(ts_array.is_null(2));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+ }
+ }
}
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs
b/native/spark-expr/src/conversion_funcs/cast.rs
index ff09dbe06..a9e688814 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -16,14 +16,15 @@
// under the License.
use crate::conversion_funcs::boolean::{
- cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
+ cast_boolean_to_decimal, cast_boolean_to_timestamp,
is_df_cast_from_bool_spark_compatible,
};
use crate::conversion_funcs::numeric::{
- cast_float32_to_decimal128, cast_float64_to_decimal128,
cast_int_to_decimal128,
- cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
- is_df_cast_from_float_spark_compatible,
is_df_cast_from_int_spark_compatible,
- spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8,
spark_cast_float64_to_utf8,
- spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
+ cast_decimal_to_timestamp, cast_float32_to_decimal128,
cast_float64_to_decimal128,
+ cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
+ is_df_cast_from_decimal_spark_compatible,
is_df_cast_from_float_spark_compatible,
+ is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
+ spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
spark_cast_int_to_int,
+ spark_cast_nonintegral_numeric_to_integral,
};
use crate::conversion_funcs::string::{
cast_string_to_date, cast_string_to_decimal, cast_string_to_float,
cast_string_to_int,
@@ -384,6 +385,9 @@ pub(crate) fn cast_array(
cast_boolean_to_decimal(&array, *precision, *scale)
}
(Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) =>
cast_int_to_timestamp(&array, tz),
+ (Float32 | Float64, Timestamp(_, tz)) =>
cast_float_to_timestamp(&array, tz, eval_mode),
+ (Boolean, Timestamp(_, tz)) => cast_boolean_to_timestamp(&array, tz),
+ (Decimal128(_, scale), Timestamp(_, tz)) =>
cast_decimal_to_timestamp(&array, tz, *scale),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(&from_type, to_type) =>
{
diff --git a/native/spark-expr/src/conversion_funcs/numeric.rs
b/native/spark-expr/src/conversion_funcs/numeric.rs
index d204e2871..59a65fb49 100644
--- a/native/spark-expr/src/conversion_funcs/numeric.rs
+++ b/native/spark-expr/src/conversion_funcs/numeric.rs
@@ -24,7 +24,7 @@ use arrow::array::{
OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder,
};
use arrow::datatypes::{
- is_validate_decimal_precision, ArrowPrimitiveType, DataType,
Decimal128Type, Float32Type,
+ i256, is_validate_decimal_precision, ArrowPrimitiveType, DataType,
Decimal128Type, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
};
use num::{cast::AsPrimitive, ToPrimitive, Zero};
@@ -75,6 +75,56 @@ pub(crate) fn
is_df_cast_from_decimal_spark_compatible(to_type: &DataType) -> bo
)
}
+macro_rules! cast_float_to_timestamp_impl {
+ ($array:expr, $builder:expr, $primitive_type:ty, $eval_mode:expr) => {{
+ let arr = $array.as_primitive::<$primitive_type>();
+ for i in 0..arr.len() {
+ if arr.is_null(i) {
+ $builder.append_null();
+ } else {
+ let val = arr.value(i) as f64;
+ // Path 1: NaN/Infinity check - error says TIMESTAMP
+ if val.is_nan() || val.is_infinite() {
+ if $eval_mode == EvalMode::Ansi {
+ return Err(SparkError::CastInvalidValue {
+ value: val.to_string(),
+ from_type: "DOUBLE".to_string(),
+ to_type: "TIMESTAMP".to_string(),
+ });
+ }
+ $builder.append_null();
+ } else {
+ // Path 2: Multiply then check overflow - error says BIGINT
+ let micros = val * MICROS_PER_SECOND as f64;
+ if micros.floor() <= i64::MAX as f64 && micros.ceil() >=
i64::MIN as f64 {
+ $builder.append_value(micros as i64);
+ } else {
+ if $eval_mode == EvalMode::Ansi {
+ let value_str = if micros.is_infinite() {
+ if micros.is_sign_positive() {
+ "Infinity".to_string()
+ } else {
+ "-Infinity".to_string()
+ }
+ } else if micros.is_nan() {
+ "NaN".to_string()
+ } else {
+ format!("{:e}", micros).to_uppercase() + "D"
+ };
+ return Err(SparkError::CastOverFlow {
+ value: value_str,
+ from_type: "DOUBLE".to_string(),
+ to_type: "BIGINT".to_string(),
+ });
+ }
+ $builder.append_null();
+ }
+ }
+ }
+ }
+ }};
+}
+
macro_rules! cast_float_to_string {
($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty)
=> {{
@@ -913,6 +963,57 @@ pub(crate) fn cast_int_to_timestamp(
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as
ArrayRef)
}
+pub(crate) fn cast_decimal_to_timestamp(
+ array_ref: &ArrayRef,
+ target_tz: &Option<Arc<str>>,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let arr = array_ref.as_primitive::<Decimal128Type>();
+ let scale_factor = 10_i128.pow(scale as u32);
+ let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len());
+
+ for i in 0..arr.len() {
+ if arr.is_null(i) {
+ builder.append_null();
+ } else {
+ let value = arr.value(i);
+ // Note: spark's big decimal truncates to
+            // long value and does not throw error (in all eval modes)
+ let value_256 = i256::from_i128(value);
+ let micros_256 = value_256 * i256::from_i128(MICROS_PER_SECOND as
i128);
+ let ts = micros_256 / i256::from_i128(scale_factor);
+ builder.append_value(ts.as_i128() as i64);
+ }
+ }
+
+ Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as
ArrayRef)
+}
+
+pub(crate) fn cast_float_to_timestamp(
+ array_ref: &ArrayRef,
+ target_tz: &Option<Arc<str>>,
+ eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+ let mut builder =
TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+ match array_ref.data_type() {
+ DataType::Float32 => {
+ cast_float_to_timestamp_impl!(array_ref, builder, Float32Type,
eval_mode)
+ }
+ DataType::Float64 => {
+ cast_float_to_timestamp_impl!(array_ref, builder, Float64Type,
eval_mode)
+ }
+ dt => {
+ return Err(SparkError::Internal(format!(
+ "Unsupported type for cast_float_to_timestamp: {:?}",
+ dt
+ )))
+ }
+ }
+
+ Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as
ArrayRef)
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1100,4 +1201,113 @@ mod tests {
assert!(casted.is_null(8));
assert!(casted.is_null(9));
}
+
+ #[test]
+ fn test_cast_decimal_to_timestamp() {
+ let timezones: [Option<Arc<str>>; 3] = [
+ Some(Arc::from("UTC")),
+ Some(Arc::from("America/Los_Angeles")),
+ None,
+ ];
+
+ for tz in &timezones {
+ // Decimal128 with scale 6
+ let decimal_array: ArrayRef = Arc::new(
+ Decimal128Array::from(vec![
+ Some(0_i128),
+ Some(1_000_000_i128),
+ Some(-1_000_000_i128),
+ Some(1_500_000_i128),
+ Some(123_456_789_i128),
+ None,
+ ])
+ .with_precision_and_scale(18, 6)
+ .unwrap(),
+ );
+
+ let result = cast_decimal_to_timestamp(&decimal_array, tz,
6).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert_eq!(ts_array.value(3), 1_500_000);
+ assert_eq!(ts_array.value(4), 123_456_789);
+ assert!(ts_array.is_null(5));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+ // Test with scale 2
+ let decimal_array: ArrayRef = Arc::new(
+ Decimal128Array::from(vec![Some(100_i128), Some(150_i128),
Some(-250_i128)])
+ .with_precision_and_scale(10, 2)
+ .unwrap(),
+ );
+
+ let result = cast_decimal_to_timestamp(&decimal_array, tz,
2).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 1_000_000);
+ assert_eq!(ts_array.value(1), 1_500_000);
+ assert_eq!(ts_array.value(2), -2_500_000);
+ }
+ }
+
+ #[test]
+ fn test_cast_float_to_timestamp() {
+ let timezones: [Option<Arc<str>>; 3] = [
+ Some(Arc::from("UTC")),
+ Some(Arc::from("America/Los_Angeles")),
+ None,
+ ];
+ let eval_modes = [EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try];
+
+ for tz in &timezones {
+ for eval_mode in &eval_modes {
+ // Float64 tests
+ let f64_array: ArrayRef = Arc::new(Float64Array::from(vec![
+ Some(0.0),
+ Some(1.0),
+ Some(-1.0),
+ Some(1.5),
+ Some(0.000001),
+ None,
+ ]));
+
+ let result = cast_float_to_timestamp(&f64_array, tz,
*eval_mode).unwrap();
+ let ts_array =
result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert_eq!(ts_array.value(3), 1_500_000);
+ assert_eq!(ts_array.value(4), 1);
+ assert!(ts_array.is_null(5));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s|
s.as_ref()));
+
+ // Float32 tests
+ let f32_array: ArrayRef = Arc::new(Float32Array::from(vec![
+ Some(0.0_f32),
+ Some(1.0_f32),
+ Some(-1.0_f32),
+ None,
+ ]));
+
+ let result = cast_float_to_timestamp(&f32_array, tz,
*eval_mode).unwrap();
+ let ts_array =
result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert!(ts_array.is_null(3));
+ }
+ }
+
+ // ANSI mode errors on NaN/Infinity
+ let tz = &Some(Arc::from("UTC"));
+ let f64_nan: ArrayRef =
Arc::new(Float64Array::from(vec![Some(f64::NAN)]));
+ assert!(cast_float_to_timestamp(&f64_nan, tz,
EvalMode::Ansi).is_err());
+
+ let f64_inf: ArrayRef =
Arc::new(Float64Array::from(vec![Some(f64::INFINITY)]));
+ assert!(cast_float_to_timestamp(&f64_inf, tz,
EvalMode::Ansi).is_err());
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 15dfcb2d7..95d536690 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -21,7 +21,7 @@ package org.apache.comet.expressions
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression,
Literal}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes,
DecimalType, NullType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes,
DecimalType, NullType, StructType, TimestampType}
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -63,16 +63,17 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
cast: Cast,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
+ val cometEvalMode = evalMode(cast)
cast.child match {
case _: Literal =>
exprToProtoInternal(Literal.create(cast.eval(), cast.dataType),
inputs, binding)
case _ =>
- if (isAlwaysCastToNull(cast.child.dataType, cast.dataType,
evalMode(cast))) {
+ if (isAlwaysCastToNull(cast.child.dataType, cast.dataType,
cometEvalMode)) {
exprToProtoInternal(Literal.create(null, cast.dataType), inputs,
binding)
} else {
val childExpr = exprToProtoInternal(cast.child, inputs, binding)
if (childExpr.isDefined) {
- castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get,
evalMode(cast))
+ castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get,
cometEvalMode)
} else {
withInfo(cast, cast.child)
None
@@ -165,7 +166,7 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
case (_: DecimalType, _) =>
canCastFromDecimal(toType)
case (DataTypes.BooleanType, _) =>
- canCastFromBoolean(toType)
+ canCastFromBoolean(toType, evalMode)
case (DataTypes.ByteType, _) =>
canCastFromByte(toType, evalMode)
case (DataTypes.ShortType, _) =>
@@ -282,12 +283,15 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
}
}
- private def canCastFromBoolean(toType: DataType): SupportLevel = toType
match {
- case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType |
- DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
- Compatible()
- case _ => unsupported(DataTypes.BooleanType, toType)
- }
+ private def canCastFromBoolean(toType: DataType, evalMode:
CometEvalMode.Value): SupportLevel =
+ toType match {
+ case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType |
+ DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+ Compatible()
+ case _: TimestampType if evalMode == CometEvalMode.LEGACY =>
+ Compatible()
+ case _ => unsupported(DataTypes.BooleanType, toType)
+ }
private def canCastFromByte(toType: DataType, evalMode:
CometEvalMode.Value): SupportLevel =
toType match {
@@ -357,7 +361,7 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType |
DataTypes.ShortType |
- DataTypes.IntegerType | DataTypes.LongType =>
+ DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType =>
Compatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1371
@@ -368,7 +372,7 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
private def canCastFromDouble(toType: DataType): SupportLevel = toType match
{
case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType |
DataTypes.ShortType |
- DataTypes.IntegerType | DataTypes.LongType =>
+ DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType =>
Compatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1371
@@ -378,7 +382,8 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
private def canCastFromDecimal(toType: DataType): SupportLevel = toType
match {
case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType |
DataTypes.ShortType |
- DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType =>
+ DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType |
+ DataTypes.TimestampType =>
Compatible()
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not
supported"))
}
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 72c2390d7..48242a978 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -167,9 +167,9 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(generateBools(), DataTypes.StringType)
}
- ignore("cast BooleanType to TimestampType") {
- // Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond,
Some("UTC")) not supported
- castTest(generateBools(), DataTypes.TimestampType)
+ test("cast BooleanType to TimestampType") {
+ // Spark does not support ANSI or Try mode for Boolean to Timestamp casts
+ castTest(generateBools(), DataTypes.TimestampType, testAnsi = false,
testTry = false)
}
// CAST from ByteType
@@ -504,9 +504,13 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(withNulls(values).toDF("a"), DataTypes.StringType)
}
- ignore("cast FloatType to TimestampType") {
- // java.lang.ArithmeticException: long overflow
- castTest(generateFloats(), DataTypes.TimestampType)
+ test("cast FloatType to TimestampType") {
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        // Use useDataFrameDiff to avoid collect(), which fails on extreme
timestamp values
+ castTest(generateFloats(), DataTypes.TimestampType, useDataFrameDiff =
true)
+ }
+ }
}
// CAST from DoubleType
@@ -560,9 +564,13 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(withNulls(values).toDF("a"), DataTypes.StringType)
}
- ignore("cast DoubleType to TimestampType") {
- // java.lang.ArithmeticException: long overflow
- castTest(generateDoubles(), DataTypes.TimestampType)
+ test("cast DoubleType to TimestampType") {
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        // Use useDataFrameDiff to avoid collect(), which fails on extreme
timestamp values
+ castTest(generateDoubles(), DataTypes.TimestampType, useDataFrameDiff
= true)
+ }
+ }
}
// CAST from DecimalType(10,2)
@@ -627,11 +635,14 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
}
- ignore("cast DecimalType(10,2) to TimestampType") {
- // input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211,
actual: 1969-12-31 15:59:59.876544
+ test("cast DecimalType(10,2) to TimestampType") {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
}
+  test("cast DecimalType(38,18) to TimestampType") {
+ castTest(generateDecimalsPrecision38Scale18(), DataTypes.TimestampType)
+ }
+
// CAST from StringType
test("cast StringType to BooleanType") {
@@ -1466,7 +1477,8 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
toType: DataType,
hasIncompatibleType: Boolean = false,
testAnsi: Boolean = true,
- testTry: Boolean = true): Unit = {
+ testTry: Boolean = true,
+ useDataFrameDiff: Boolean = false): Unit = {
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
@@ -1474,22 +1486,29 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
// cast() should return null for invalid inputs when ansi mode is
disabled
val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
- if (hasIncompatibleType) {
- checkSparkAnswer(df)
+ if (useDataFrameDiff) {
+ assertDataFrameEqualsWithExceptions(df, assertCometNative =
!hasIncompatibleType)
} else {
- checkSparkAnswerAndOperator(df)
+ if (hasIncompatibleType) {
+ checkSparkAnswer(df)
+ } else {
+ checkSparkAnswerAndOperator(df)
+ }
}
if (testTry) {
data.createOrReplaceTempView("t")
-// try_cast() should always return null for invalid inputs
-// not using spark DSL since it `try_cast` is only available from
Spark 4x
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
+        // try_cast() should always return null for invalid inputs
+        // not using Spark DSL since `try_cast` is only available from
Spark 4.x
+ val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t
order by a")
if (hasIncompatibleType) {
checkSparkAnswer(df2)
} else {
- checkSparkAnswerAndOperator(df2)
+ if (useDataFrameDiff) {
+ assertDataFrameEqualsWithExceptions(df2, assertCometNative =
!hasIncompatibleType)
+ } else {
+ checkSparkAnswerAndOperator(df2)
+ }
}
}
}
@@ -1502,7 +1521,12 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
// cast() should throw exception on invalid inputs when ansi mode is
enabled
val df = data.withColumn("converted", col("a").cast(toType))
- checkSparkAnswerMaybeThrows(df) match {
+ val res = if (useDataFrameDiff) {
+ assertDataFrameEqualsWithExceptions(df, assertCometNative =
!hasIncompatibleType)
+ } else {
+ checkSparkAnswerMaybeThrows(df)
+ }
+ res match {
case (None, None) =>
// neither system threw an exception
case (None, Some(e)) =>
@@ -1546,12 +1570,15 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
}
+ }
- // try_cast() should always return null for invalid inputs
- if (testTry) {
- data.createOrReplaceTempView("t")
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order
by a")
+ // try_cast() should always return null for invalid inputs
+ if (testTry) {
+ data.createOrReplaceTempView("t")
+ val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t
order by a")
+ if (useDataFrameDiff) {
+ assertDataFrameEqualsWithExceptions(df2, assertCometNative =
!hasIncompatibleType)
+ } else {
if (hasIncompatibleType) {
checkSparkAnswer(df2)
} else {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 41080ed9e..f831d53bf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -1276,4 +1276,53 @@ abstract class CometTestBase
!usingLegacyNativeCometScan(conf) &&
CometConf.COMET_PARQUET_UNSIGNED_SMALL_INT_CHECK.get(conf)
}
+
+ /**
+ * Compares Spark and Comet results using foreach() and exceptAll() to avoid
collect()
+ */
+ protected def assertDataFrameEqualsWithExceptions(
+ df: => DataFrame,
+ assertCometNative: Boolean = true): (Option[Throwable],
Option[Throwable]) = {
+
+ var expected: Try[Unit] = null
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ expected = Try(datasetOfRows(spark, df.logicalPlan).foreach(_ => ()))
+ }
+ val actual = Try(datasetOfRows(spark, df.logicalPlan).foreach(_ => ()))
+
+ (expected, actual) match {
+ case (Success(_), Success(_)) =>
+ // compare results and confirm that they match
+ var dfSpark: DataFrame = null
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ dfSpark = datasetOfRows(spark, df.logicalPlan)
+ }
+ val dfComet = datasetOfRows(spark, df.logicalPlan)
+
+ // Compare schemas
+ assert(
+ dfSpark.schema == dfComet.schema,
+ s"Schema mismatch:\nSpark: ${dfSpark.schema}\nComet:
${dfComet.schema}")
+
+ val sparkMinusComet = dfSpark.exceptAll(dfComet)
+ val cometMinusSpark = dfComet.exceptAll(dfSpark)
+ val diffCount1 = sparkMinusComet.count()
+ val diffCount2 = cometMinusSpark.count()
+
+ if (diffCount1 > 0 || diffCount2 > 0) {
+ fail(
+ "Results do not match. " +
+ s"Rows in Spark but not Comet: $diffCount1. " +
+ s"Rows in Comet but not Spark: $diffCount2.")
+ }
+
+ if (assertCometNative) {
+
checkCometOperators(stripAQEPlan(dfComet.queryExecution.executedPlan))
+ }
+
+ (None, None)
+ case _ =>
+ (expected.failed.toOption, actual.failed.toOption)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]