alamb commented on code in PR #15354:
URL: https://github.com/apache/datafusion/pull/15354#discussion_r2010786252
##########
datafusion/functions-nested/src/array_has.rs:
##########
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
Ok(DataType::Boolean)
}
+ fn simplify(
+ &self,
+ mut args: Vec<Expr>,
+ _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+ ) -> Result<ExprSimplifyResult> {
+ let [haystack, needle] = take_function_args(self.name(), &mut args)?;
+
+ // if the haystack is a constant list, we can use an inlist expression
which is more
+ // efficient because the haystack is not varying per-row
+ if let Expr::Literal(ScalarValue::List(array)) = haystack {
+ // TODO: support LargeList
+ // (not supported by `convert_array_to_scalar_vec`)
+ // (FixedSizeList not supported either, but seems to have worked
fine when attempting to
+ // build a reproducer)
+
+ assert_eq!(array.len(), 1); // guarantee of ScalarValue
+ if let Ok(scalar_values) =
+ ScalarValue::convert_array_to_scalar_vec(array.as_ref())
+ {
+ assert_eq!(scalar_values.len(), 1);
+ let list = scalar_values
+ .into_iter()
+ .flatten()
+ .map(Expr::Literal)
+ .collect();
+
+ return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+ expr: Box::new(std::mem::take(needle)),
+ list,
+ negated: false,
+ })));
+ }
+ } else if let Expr::ScalarFunction(ScalarFunction { func, args }) =
haystack {
+ // make_array has a static set of arguments, so we can pull the
arguments out from it
Review Comment:
I tested removing this case and the slt tests failed like this
```diff
Completed 113 test files in 3 seconds
External error: query result mismatch:
[SQL] explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle
FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has([needle], needle);
[Diff] (-expected|+actual)
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
- 05)--------Projection:
- 06)----------Filter: __common_expr_3 = __common_expr_3
+ 05)--------Projection:
+ 06)----------Filter: array_has(make_array(__common_expr_3),
__common_expr_3)
07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS
Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3
08)--------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
- 07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
+ 07)------------FilterExec: array_has(make_array(__common_expr_3@0),
__common_expr_3@0)
08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS
Utf8)), 1, 32) as __common_expr_3]
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
10)------------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
at test_files/array.slt:6120
```
I found that unexpected but don't have time to look into it more now
##########
datafusion/functions-nested/src/array_has.rs:
##########
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
Ok(DataType::Boolean)
}
+ fn simplify(
+ &self,
+ mut args: Vec<Expr>,
+ _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+ ) -> Result<ExprSimplifyResult> {
+ let [haystack, needle] = take_function_args(self.name(), &mut args)?;
+
+ // if the haystack is a constant list, we can use an inlist expression
which is more
+ // efficient because the haystack is not varying per-row
+ if let Expr::Literal(ScalarValue::List(array)) = haystack {
+ // TODO: support LargeList
+ // (not supported by `convert_array_to_scalar_vec`)
+ // (FixedSizeList not supported either, but seems to have worked
fine when attempting to
+ // build a reproducer)
+
+ assert_eq!(array.len(), 1); // guarantee of ScalarValue
+ if let Ok(scalar_values) =
+ ScalarValue::convert_array_to_scalar_vec(array.as_ref())
+ {
+ assert_eq!(scalar_values.len(), 1);
+ let list = scalar_values
+ .into_iter()
+ .flatten()
+ .map(Expr::Literal)
+ .collect();
+
+ return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+ expr: Box::new(std::mem::take(needle)),
+ list,
+ negated: false,
+ })));
+ }
+ } else if let Expr::ScalarFunction(ScalarFunction { func, args }) =
haystack {
+ // make_array has a static set of arguments, so we can pull the
arguments out from it
Review Comment:
I would expect that, during constant evaluation, `make_array` would be turned
into a literal, so this case would be unnecessary.
However, you wouldn't observe that simplification happening in unit tests
(only in the slt tests, when everything is put together).
##########
datafusion/sqllogictest/test_files/array.slt:
##########
@@ -5960,6 +5960,188 @@ true false true false false false true true false false
true false true
#----
#true false true false false false true true false false true false true
+# rewrite various array_has operations to InList where the haystack is a
literal list
+# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements,
so we make 4-element haystack lists
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c');
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c');
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal {
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value:
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle =
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle =
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
Review Comment:
that is cool to see
##########
datafusion/sqllogictest/test_files/array.slt:
##########
@@ -5960,6 +5960,188 @@ true false true false false false true true false false
true false true
#----
#true false true false false false true true false false true false true
+# rewrite various array_has operations to InList where the haystack is a
literal list
+# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements,
so we make 4-element haystack lists
Review Comment:
👍
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]