alamb commented on code in PR #15354:
URL: https://github.com/apache/datafusion/pull/15354#discussion_r2010786252


##########
datafusion/functions-nested/src/array_has.rs:
##########
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
         Ok(DataType::Boolean)
     }
 
+    fn simplify(
+        &self,
+        mut args: Vec<Expr>,
+        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        let [haystack, needle] = take_function_args(self.name(), &mut args)?;
+
+        // if the haystack is a constant list, we can use an inlist expression 
which is more
+        // efficient because the haystack is not varying per-row
+        if let Expr::Literal(ScalarValue::List(array)) = haystack {
+            // TODO: support LargeList
+            // (not supported by `convert_array_to_scalar_vec`)
+            // (FixedSizeList not supported either, but seems to have worked 
fine when attempting to
+            // build a reproducer)
+
+            assert_eq!(array.len(), 1); // guarantee of ScalarValue
+            if let Ok(scalar_values) =
+                ScalarValue::convert_array_to_scalar_vec(array.as_ref())
+            {
+                assert_eq!(scalar_values.len(), 1);
+                let list = scalar_values
+                    .into_iter()
+                    .flatten()
+                    .map(Expr::Literal)
+                    .collect();
+
+                return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+                    expr: Box::new(std::mem::take(needle)),
+                    list,
+                    negated: false,
+                })));
+            }
+        } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = 
haystack {
+            // make_array has a static set of arguments, so we can pull the 
arguments out from it

Review Comment:
   I tested removing this case and the slt tests failed like this
   
   ```diff
   Completed 113 test files in 3 seconds                                        
                                                                                
         External error: query result mismatch:
   [SQL] explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle 
FROM generate_series(1, 100000) t(i))
   select count(*) from test WHERE array_has([needle], needle);
   [Diff] (-expected|+actual)
       logical_plan
       01)Projection: count(Int64(1)) AS count(*)
       02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
       03)----SubqueryAlias: test
       04)------SubqueryAlias: t
   -   05)--------Projection:
   -   06)----------Filter: __common_expr_3 = __common_expr_3
   +   05)--------Projection:
   +   06)----------Filter: array_has(make_array(__common_expr_3), 
__common_expr_3)
       07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS 
Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3
       08)--------------TableScan: tmp_table projection=[value]
       physical_plan
       01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
       02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
       03)----CoalescePartitionsExec
       04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
       05)--------ProjectionExec: expr=[]
       06)----------CoalesceBatchesExec: target_batch_size=8192
   -   07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
   +   07)------------FilterExec: array_has(make_array(__common_expr_3@0), 
__common_expr_3@0)
       08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS 
Utf8)), 1, 32) as __common_expr_3]
       09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1
       10)------------------LazyMemoryExec: partitions=1, 
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
   at test_files/array.slt:6120
   ```
   
   I found that unexpected, but I don't have time to look into it further right now.



##########
datafusion/functions-nested/src/array_has.rs:
##########
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
         Ok(DataType::Boolean)
     }
 
+    fn simplify(
+        &self,
+        mut args: Vec<Expr>,
+        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        let [haystack, needle] = take_function_args(self.name(), &mut args)?;
+
+        // if the haystack is a constant list, we can use an inlist expression 
which is more
+        // efficient because the haystack is not varying per-row
+        if let Expr::Literal(ScalarValue::List(array)) = haystack {
+            // TODO: support LargeList
+            // (not supported by `convert_array_to_scalar_vec`)
+            // (FixedSizeList not supported either, but seems to have worked 
fine when attempting to
+            // build a reproducer)
+
+            assert_eq!(array.len(), 1); // guarantee of ScalarValue
+            if let Ok(scalar_values) =
+                ScalarValue::convert_array_to_scalar_vec(array.as_ref())
+            {
+                assert_eq!(scalar_values.len(), 1);
+                let list = scalar_values
+                    .into_iter()
+                    .flatten()
+                    .map(Expr::Literal)
+                    .collect();
+
+                return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+                    expr: Box::new(std::mem::take(needle)),
+                    list,
+                    negated: false,
+                })));
+            }
+        } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = 
haystack {
+            // make_array has a static set of arguments, so we can pull the 
arguments out from it

Review Comment:
   I would expect that during constant evaluation make_array would be turned 
into a literal so this case would be unnecessary
   
   However, you wouldn't observe that simplification happening in unit tests 
(only in the slt tests when everything was put together)



##########
datafusion/sqllogictest/test_files/array.slt:
##########
@@ -5960,6 +5960,188 @@ true false true false false false true true false false 
true false true
 #----
 #true false true false false false true true false false true false true
 
+# rewrite various array_has operations to InList where the haystack is a 
literal list
+# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, 
so we make 4-element haystack lists
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM 
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 
'a', 'b', 'c');
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM 
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 
'a', 'b', 'c');
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), 
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), 
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN 
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { 
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: 
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1, 
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM 
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle = 
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM 
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle = 
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), 
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), 
Utf8View("a"), Utf8View("b"), Utf8View("c")])

Review Comment:
   that is cool to see



##########
datafusion/sqllogictest/test_files/array.slt:
##########
@@ -5960,6 +5960,188 @@ true false true false false false true true false false 
true false true
 #----
 #true false true false false false true true false false true false true
 
+# rewrite various array_has operations to InList where the haystack is a 
literal list
+# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, 
so we make 4-element haystack lists

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to