findepi commented on code in PR #15110: URL: https://github.com/apache/datafusion/pull/15110#discussion_r2005577033
########## datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs: ########## @@ -215,6 +230,69 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } +/// Try to move a cast from a column to the other side of a `=` / `!=` operator +/// +/// Specifically, rewrites +/// ```sql +/// cast(col) <op> <literal> +/// ``` +/// +/// To +/// +/// ```sql +/// col <op> cast(<literal>) +/// col <op> <casted_literal> +/// ``` +fn cast_literal_to_type_with_op( + lit_value: &ScalarValue, + target_type: &DataType, + op: Operator, +) -> Option<ScalarValue> { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types + // like timestamps)? + use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; Review Comment: ```suggestion let cast = lit_value.cast_to(target_type).ok()?; ``` ########## datafusion/optimizer/src/analyzer/type_coercion.rs: ########## @@ -290,19 +290,72 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { + if let Expr::Literal(ref lit_value) = left { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &right.get_type(right_schema)?) + { + return Ok((casted, right)); + }; + } + + if let Expr::Literal(ref lit_value) = right { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &left.get_type(left_schema)?) 
+ { + return Ok((left, casted)); + }; + } + let (left_type, right_type) = BinaryTypeCoercer::new( &left.get_type(left_schema)?, &op, &right.get_type(right_schema)?, ) .get_input_types()?; + Ok(( left.cast_to(&left_type, left_schema)?, right.cast_to(&right_type, right_schema)?, )) } } +fn try_cast_literal_to_type( + lit_value: &ScalarValue, + op: Operator, + target_type: &DataType, +) -> Option<Expr> { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types like timestamps)? + use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; Review Comment: ```suggestion let cast = lit_value.cast_to(target_type).ok()?; ``` ########## datafusion/optimizer/src/analyzer/type_coercion.rs: ########## @@ -290,19 +290,72 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { + if let Expr::Literal(ref lit_value) = left { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &right.get_type(right_schema)?) + { + return Ok((casted, right)); + }; + } + + if let Expr::Literal(ref lit_value) = right { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &left.get_type(left_schema)?) Review Comment: We swapped left and right, but did nothing to `op`. 
This is correct only when `op` is symmetric (e.g. `=`, `<>`, but not `>` or `<`). ########## datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs: ########## @@ -215,6 +230,69 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } +/// Try to move a cast from a column to the other side of a `=` / `!=` operator +/// +/// Specifically, rewrites +/// ```sql +/// cast(col) <op> <literal> +/// ``` +/// +/// To +/// +/// ```sql +/// col <op> cast(<literal>) +/// col <op> <casted_literal> +/// ``` +fn cast_literal_to_type_with_op( + lit_value: &ScalarValue, + target_type: &DataType, + op: Operator, +) -> Option<ScalarValue> { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types + // like timestamps)? + use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; + + // Perform a round-trip cast: literal -> target_type -> original_type + // Ensures cast expressions involving values like '0123' are not unwrapped for correctness (e.g., `cast(c1, UTF8) = '0123'`) + let round_tripped = arrow::compute::cast_with_options( + &casted, + &lit_value.data_type(), + &opts, + ) + .ok()?; + + if array != round_tripped { + return None; + } Review Comment: ```suggestion let round_tripped = cast.cast_to(&lit_value.data_type()).ok()?; if lit_value != &round_tripped { return None; } ``` ########## datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs: ########## @@ -215,6 +230,69 @@ fn is_supported_dictionary_type(data_type: 
&DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } +/// Try to move a cast from a column to the other side of a `=` / `!=` operator Review Comment: If the function were limited to `=` and `!=`, it would take a boolean or a two-valued enum instead of `op: Operator`. ```suggestion /// Tries to move a cast from an expression (such as a column) to the literal on the other side of a comparison operator. ``` ########## datafusion/optimizer/src/analyzer/type_coercion.rs: ########## @@ -290,19 +290,72 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { + if let Expr::Literal(ref lit_value) = left { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &right.get_type(right_schema)?) + { + return Ok((casted, right)); + }; + } + + if let Expr::Literal(ref lit_value) = right { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &left.get_type(left_schema)?) + { + return Ok((left, casted)); + }; + } + let (left_type, right_type) = BinaryTypeCoercer::new( &left.get_type(left_schema)?, &op, &right.get_type(right_schema)?, ) .get_input_types()?; + Ok(( left.cast_to(&left_type, left_schema)?, right.cast_to(&right_type, right_schema)?, )) } } +fn try_cast_literal_to_type( + lit_value: &ScalarValue, + op: Operator, + target_type: &DataType, +) -> Option<Expr> { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types like timestamps)? 
+ use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; + ScalarValue::try_from_array(&casted, 0) Review Comment: Any reason this code doesn't check for round-tripping as the other code in `unwrap_cast.rs` does? Also, do we need to have the same/similar logic in two files? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org