alamb commented on code in PR #15354: URL: https://github.com/apache/datafusion/pull/15354#discussion_r2010786252
########## datafusion/functions-nested/src/array_has.rs: ########## @@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas { Ok(DataType::Boolean) } + fn simplify( + &self, + mut args: Vec<Expr>, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result<ExprSimplifyResult> { + let [haystack, needle] = take_function_args(self.name(), &mut args)?; + + // if the haystack is a constant list, we can use an inlist expression which is more + // efficient because the haystack is not varying per-row + if let Expr::Literal(ScalarValue::List(array)) = haystack { + // TODO: support LargeList + // (not supported by `convert_array_to_scalar_vec`) + // (FixedSizeList not supported either, but seems to have worked fine when attempting to + // build a reproducer) + + assert_eq!(array.len(), 1); // guarantee of ScalarValue + if let Ok(scalar_values) = + ScalarValue::convert_array_to_scalar_vec(array.as_ref()) + { + assert_eq!(scalar_values.len(), 1); + let list = scalar_values + .into_iter() + .flatten() + .map(Expr::Literal) + .collect(); + + return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { + expr: Box::new(std::mem::take(needle)), + list, + negated: false, + }))); + } + } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack { + // make_array has a static set of arguments, so we can pull the arguments out from it Review Comment: I tested removing this case and the slt tests failed like this ```diff Completed 113 test files in 3 seconds External error: query result mismatch: [SQL] explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has([needle], needle); [Diff] (-expected|+actual) logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: test 04)------SubqueryAlias: t - 05)--------Projection: - 06)----------Filter: __common_expr_3 = __common_expr_3 + 05)--------Projection: + 
06)----------Filter: array_has(make_array(__common_expr_3), __common_expr_3) 07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3 08)--------------TableScan: tmp_table projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 - 07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0 + 07)------------FilterExec: array_has(make_array(__common_expr_3@0), __common_expr_3@0) 08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] at test_files/array.slt:6120 ``` I found that unexpected but don't have time to look into it more now ########## datafusion/functions-nested/src/array_has.rs: ########## @@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas { Ok(DataType::Boolean) } + fn simplify( + &self, + mut args: Vec<Expr>, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result<ExprSimplifyResult> { + let [haystack, needle] = take_function_args(self.name(), &mut args)?; + + // if the haystack is a constant list, we can use an inlist expression which is more + // efficient because the haystack is not varying per-row + if let Expr::Literal(ScalarValue::List(array)) = haystack { + // TODO: support LargeList + // (not supported by `convert_array_to_scalar_vec`) + // (FixedSizeList not supported either, but seems to have worked fine when attempting to + // build a reproducer) + + assert_eq!(array.len(), 1); // guarantee of ScalarValue + if let 
Ok(scalar_values) = + ScalarValue::convert_array_to_scalar_vec(array.as_ref()) + { + assert_eq!(scalar_values.len(), 1); + let list = scalar_values + .into_iter() + .flatten() + .map(Expr::Literal) + .collect(); + + return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { + expr: Box::new(std::mem::take(needle)), + list, + negated: false, + }))); + } + } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack { + // make_array has a static set of arguments, so we can pull the arguments out from it Review Comment: I would expect that during constant evaluation make_array would be turned into a literal so this case would be unnecessary. However, you wouldn't observe that simplification happening in unit tests (only in the slt tests when everything was put together) ########## datafusion/sqllogictest/test_files/array.slt: ########## @@ -5960,6 +5960,188 @@ true false true false false false true true false false true false true #---- #true false true false false false true true false false true false true +# rewrite various array_has operations to InList where the haystack is a literal list +# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN 
([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) Review Comment: that is cool to see ########## datafusion/sqllogictest/test_files/array.slt: ########## @@ -5960,6 +5960,188 @@ true false true false false false true true false false true false true #---- #true false true false 
false false true true false false true false true +# rewrite various array_has operations to InList where the haystack is a literal list +# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists Review Comment: 👍 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org