jayzhan211 commented on code in PR #14440: URL: https://github.com/apache/datafusion/pull/14440#discussion_r1950091410
########## datafusion/expr/src/type_coercion/functions.rs: ########## @@ -596,75 +594,93 @@ fn get_valid_types( vec![vec![target_type; *num]] } } - TypeSignature::Coercible(target_types) => { - function_length_check( - function_name, - current_types.len(), - target_types.len(), - )?; - - // Aim to keep this logic as SIMPLE as possible! - // Make sure the corresponding test is covered - // If this function becomes COMPLEX, create another new signature! - fn can_coerce_to( - function_name: &str, - current_type: &DataType, - target_type_class: &TypeSignatureClass, - ) -> Result<DataType> { - let logical_type: NativeType = current_type.into(); + TypeSignature::Coercible(param_types) => { + function_length_check(function_name, current_types.len(), param_types.len())?; - match target_type_class { - TypeSignatureClass::Native(native_type) => { - let target_type = native_type.native(); - if &logical_type == target_type { - return target_type.default_cast_for(current_type); - } + let mut new_types = Vec::with_capacity(current_types.len()); + for (current_type, param) in current_types.iter().zip(param_types.iter()) { + let current_logical_type: NativeType = current_type.into(); + + fn is_matched_type( + target_type: &TypeSignatureClass, + logical_type: &NativeType, + ) -> bool { + if logical_type == &NativeType::Null { + return true; + } - if logical_type == NativeType::Null { - return target_type.default_cast_for(current_type); + match target_type { + TypeSignatureClass::Native(t) if t.native() == logical_type => { + true } - - if target_type.is_integer() && logical_type.is_integer() { - return target_type.default_cast_for(current_type); + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + true } - - internal_err!( - "Function '{function_name}' expects {target_type_class} but received {current_type}" - ) - } - // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp - if 
logical_type == NativeType::String => - { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Time if logical_type.is_time() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Interval if logical_type.is_interval() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Duration if logical_type.is_duration() => { - Ok(current_type.to_owned()) + TypeSignatureClass::Time if logical_type.is_time() => true, + TypeSignatureClass::Interval if logical_type.is_interval() => { + true + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + true + } + TypeSignatureClass::Integer if logical_type.is_integer() => true, + _ => false, } - _ => { - not_impl_err!("Function '{function_name}' got logical_type: {logical_type} with target_type_class: {target_type_class}") + } + + fn default_casted_type( + signature_class: &TypeSignatureClass, + logical_type: &NativeType, + origin_type: &DataType, + ) -> Result<DataType> { + match signature_class { + TypeSignatureClass::Native(logical_type) => { + logical_type.native().default_cast_for(origin_type) + } + // If the given type is already a timestamp, we don't change the unit and timezone + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Time if logical_type.is_time() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Integer if logical_type.is_integer() => { + Ok(origin_type.to_owned()) + } + _ => internal_err!("May miss the matching logic in `is_matched_type`"), } } - } - let mut new_types = 
Vec::with_capacity(current_types.len()); - for (current_type, target_type_class) in - current_types.iter().zip(target_types.iter()) - { - let target_type = can_coerce_to(function_name, current_type, target_type_class)?; - new_types.push(target_type); + if is_matched_type(&param.desired_type, &current_logical_type) { + let casted_type = default_casted_type( + &param.desired_type, + &current_logical_type, + current_type, + )?; + + new_types.push(casted_type); + } else if param + .allowed_source_types() + .iter() + .any(|t| is_matched_type(t, &current_logical_type)) { + // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap + let default_casted_type = param.default_casted_type().unwrap(); + let casted_type = default_casted_type.default_cast_for(current_type)?; Review Comment: Didn't find more simplified logic -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org