geoffreyclaude commented on code in PR #18832:
URL: https://github.com/apache/datafusion/pull/18832#discussion_r2610778034
##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -198,99 +213,302 @@ impl ArrayStaticFilter {
}
}
-struct Int32StaticFilter {
- null_count: usize,
- values: HashSet<i32>,
+/// Wrapper for f32 that implements Hash and Eq using IEEE 754 total ordering.
+/// This treats NaN values as equal to each other (using total_cmp).
+#[derive(Clone, Copy)]
+struct OrderedFloat32(f32);
+
+impl Hash for OrderedFloat32 {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.0.to_ne_bytes().hash(state);
+ }
}
-impl Int32StaticFilter {
- fn try_new(in_array: &ArrayRef) -> Result<Self> {
- let in_array = in_array
- .as_primitive_opt::<Int32Type>()
- .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
+impl PartialEq for OrderedFloat32 {
+ fn eq(&self, other: &Self) -> bool {
+ self.0.total_cmp(&other.0).is_eq()
+ }
+}
- let mut values = HashSet::with_capacity(in_array.len());
- let null_count = in_array.null_count();
+impl Eq for OrderedFloat32 {}
- for v in in_array.iter().flatten() {
- values.insert(v);
- }
+impl From<f32> for OrderedFloat32 {
+ fn from(v: f32) -> Self {
+ Self(v)
+ }
+}
- Ok(Self { null_count, values })
+/// Wrapper for f64 that implements Hash and Eq using IEEE 754 total ordering.
+/// This treats NaN values as equal to each other (using total_cmp).
+#[derive(Clone, Copy)]
+struct OrderedFloat64(f64);
+
+impl Hash for OrderedFloat64 {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.0.to_ne_bytes().hash(state);
}
}
-impl StaticFilter for Int32StaticFilter {
- fn null_count(&self) -> usize {
- self.null_count
+impl PartialEq for OrderedFloat64 {
+ fn eq(&self, other: &Self) -> bool {
+ self.0.total_cmp(&other.0).is_eq()
}
+}
- fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
- // Handle dictionary arrays by recursing on the values
- downcast_dictionary_array! {
- v => {
- let values_contains = self.contains(v.values().as_ref(), negated)?;
- let result = take(&values_contains, v.keys(), None)?;
- return Ok(downcast_array(result.as_ref()))
+impl Eq for OrderedFloat64 {}
+
+impl From<f64> for OrderedFloat64 {
+ fn from(v: f64) -> Self {
+ Self(v)
+ }
+}
+
+// Macro to generate specialized StaticFilter implementations for primitive types
+macro_rules! primitive_static_filter {
+ ($Name:ident, $ArrowType:ty) => {
+ struct $Name {
+ null_count: usize,
+ values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>,
+ }
+
+ impl $Name {
+ fn try_new(in_array: &ArrayRef) -> Result<Self> {
+ let in_array = in_array
+ .as_primitive_opt::<$ArrowType>()
+ .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
+
+ let mut values = HashSet::with_capacity(in_array.len());
+ let null_count = in_array.null_count();
+
+ for v in in_array.iter().flatten() {
+ values.insert(v);
+ }
+
+ Ok(Self { null_count, values })
}
- _ => {}
}
- let v = v
- .as_primitive_opt::<Int32Type>()
- .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
-
- let haystack_has_nulls = self.null_count > 0;
- let has_nulls = v.null_count() > 0 || haystack_has_nulls;
-
- let result = match (has_nulls, negated) {
- (true, false) => {
- // needle has nulls, not negated
- BooleanArray::from_iter(v.iter().map(|value| match value {
- None => None,
- Some(v) => {
- if self.values.contains(&v) {
- Some(true)
- } else if haystack_has_nulls {
- None
+ impl StaticFilter for $Name {
+ fn null_count(&self) -> usize {
+ self.null_count
+ }
+
+ fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+ // Handle dictionary arrays by recursing on the values
+ downcast_dictionary_array! {
+ v => {
+ let values_contains = self.contains(v.values().as_ref(), negated)?;
+ let result = take(&values_contains, v.keys(), None)?;
+ return Ok(downcast_array(result.as_ref()))
+ }
+ _ => {}
+ }
+
+ let v = v
+ .as_primitive_opt::<$ArrowType>()
+ .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
+
+ let haystack_has_nulls = self.null_count > 0;
+
+ let needle_values = v.values();
+ let needle_nulls = v.nulls();
+ let needle_has_nulls = v.null_count() > 0;
+
+ // Compute the "contains" result using collect_bool (fast batched approach)
+ // This ignores nulls - we handle them separately
+ let contains_buffer = if negated {
+ BooleanBuffer::collect_bool(needle_values.len(), |i| {
+ !self.values.contains(&needle_values[i])
+ })
+ } else {
+ BooleanBuffer::collect_bool(needle_values.len(), |i| {
+ self.values.contains(&needle_values[i])
+ })
+ };
+
+ // Compute the null mask
+ // Output is null when:
+ // 1. needle value is null, OR
+ // 2. needle value is not in set AND haystack has nulls
+ let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
+ (false, false) => {
+ // No nulls anywhere
+ None
+ }
+ (true, false) => {
+ // Only needle has nulls - just use needle's null mask
+ needle_nulls.cloned()
+ }
+ (false, true) => {
+ // Only haystack has nulls - null where not-in-set
Review Comment:
This comment block is a bit verbose and not super clear. Maybe having it in
table form would help?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]