This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new e4d0c672a30 branch-2.1: [fix](Nereids) cse extract wrong expression from lambda expressions (#49166) (#49942)
e4d0c672a30 is described below

commit e4d0c672a3083006ab4ba090258b862240c8265f
Author: 924060929 <lanhuaj...@selectdb.com>
AuthorDate: Thu Apr 10 23:22:55 2025 +0800

    branch-2.1: [fix](Nereids) cse extract wrong expression from lambda expressions (#49166) (#49942)
    
    cherry pick from #49166
    
    Co-authored-by: morrySnow <zhangwen...@selectdb.com>
---
 .../post/CommonSubExpressionCollector.java         | 24 ++++++++---
 .../processor/post/CommonSubExpressionOpt.java     |  2 +-
 .../postprocess/CommonSubExpressionTest.java       | 46 ++++++++++++++++------
 .../suites/nereids_rules_p0/cse/cse.groovy         | 39 ++++++++++++++++++
 4 files changed, 92 insertions(+), 19 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
index 877e411a539..520902c0439 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
@@ -17,7 +17,10 @@
 
 package org.apache.doris.nereids.processor.post;
 
+import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
+import 
org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 
 import java.util.HashMap;
@@ -28,29 +31,38 @@ import java.util.Set;
 /**
  * collect common expr
  */
-public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, 
Void> {
+public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, 
Boolean> {
     public final Map<Integer, Set<Expression>> commonExprByDepth = new 
HashMap<>();
     private final Map<Integer, Set<Expression>> expressionsByDepth = new 
HashMap<>();
 
+    public int collect(Expression expr) {
+        return expr.accept(this, expr instanceof Lambda);
+    }
+
     @Override
-    public Integer visit(Expression expr, Void context) {
+    public Integer visit(Expression expr, Boolean inLambda) {
         if (expr.children().isEmpty()) {
             return 0;
         }
         return collectCommonExpressionByDepth(
                 expr.children()
                         .stream()
-                        .map(child -> child.accept(this, context))
+                        .map(child -> child.accept(this, inLambda == null || 
inLambda || child instanceof Lambda))
                         .reduce(Math::max)
                         .map(m -> m + 1)
                         .orElse(1),
-                expr
+                expr,
+                inLambda == null || inLambda
         );
     }
 
-    private int collectCommonExpressionByDepth(int depth, Expression expr) {
+    private int collectCommonExpressionByDepth(int depth, Expression expr, 
boolean inLambda) {
         Set<Expression> expressions = getExpressionsFromDepthMap(depth, 
expressionsByDepth);
-        if (expressions.contains(expr)) {
+        // ArrayItemSlot and ArrayItemReference cannot be common expressions.
+        // TODO: common expressions cannot yet be extracted when expressions contain the same
+        //   lambda expression, because the ArrayItemSlots in each Lambda are not the same.
+        if (expressions.contains(expr)
+                && !(inLambda && expr.containsType(ArrayItemSlot.class, 
ArrayItemReference.class))) {
             Set<Expression> commonExpression = 
getExpressionsFromDepthMap(depth, commonExprByDepth);
             commonExpression.add(expr);
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
index fca84167994..82239f2c2c9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
@@ -65,7 +65,7 @@ public class CommonSubExpressionOpt extends PlanPostProcessor 
{
         List<List<NamedExpression>> multiLayers = Lists.newArrayList();
         CommonSubExpressionCollector collector = new 
CommonSubExpressionCollector();
         for (Expression expr : projects) {
-            expr.accept(collector, null);
+            collector.collect(expr);
         }
         Map<Expression, Alias> aliasMap = new HashMap<>();
         if (!collector.commonExprByDepth.isEmpty()) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
index 56b67e087d5..c9a20f65e69 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
@@ -20,14 +20,24 @@ package org.apache.doris.nereids.postprocess;
 import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector;
 import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt;
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.ArrayType;
 import org.apache.doris.nereids.types.IntegerType;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
@@ -37,27 +47,40 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
     @Test
     public void testExtractCommonExpr() {
         List<NamedExpression> exprs = parseProjections("a+b, a+b+1, 
abs(a+b+1), a");
-        CommonSubExpressionCollector collector =
-                new CommonSubExpressionCollector();
+        CommonSubExpressionCollector collector = new 
CommonSubExpressionCollector();
         exprs.forEach(expr -> collector.visit(expr, null));
-        System.out.println(collector.commonExprByDepth);
         Assertions.assertEquals(2, collector.commonExprByDepth.size());
-        List<Expression> l1 = 
collector.commonExprByDepth.get(Integer.valueOf(1))
-                .stream().collect(Collectors.toList());
-        List<Expression> l2 = 
collector.commonExprByDepth.get(Integer.valueOf(2))
-                .stream().collect(Collectors.toList());
+        List<Expression> l1 = new 
ArrayList<>(collector.commonExprByDepth.get(1));
+        List<Expression> l2 = new 
ArrayList<>(collector.commonExprByDepth.get(2));
         Assertions.assertEquals(1, l1.size());
         assertExpression(l1.get(0), "a+b");
         Assertions.assertEquals(1, l2.size());
         assertExpression(l2.get(0), "a+b+1");
     }
 
+    @Test
+    void testLambdaExpression() {
+        ArrayItemReference ref = new ArrayItemReference("x", new 
SlotReference(new ExprId(1), "y",
+                ArrayType.of(IntegerType.INSTANCE), true, ImmutableList.of()));
+        Expression add = new Add(ref.toSlot(), Literal.of(1));
+        Expression and = new And(add, add);
+        ArrayMap arrayMap = new ArrayMap(new Lambda(ImmutableList.of("x"), 
and, ImmutableList.of(ref)));
+        List<NamedExpression> exprs = Lists.newArrayList(
+                new Alias(new ExprId(10000), arrayMap, "c1"),
+                new Alias(new ExprId(10001), arrayMap, "c2")
+        );
+        CommonSubExpressionCollector collector = new 
CommonSubExpressionCollector();
+        exprs.forEach(expr -> collector.visit(expr, false));
+        Assertions.assertEquals(1, collector.commonExprByDepth.size());
+        Assertions.assertEquals(1, collector.commonExprByDepth.get(4).size());
+        Assertions.assertEquals(arrayMap, 
collector.commonExprByDepth.get(4).iterator().next());
+    }
+
     @Test
     public void testMultiLayers() throws Exception {
         List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, 
abs(a+b+1), a");
@@ -68,15 +91,14 @@ public class CommonSubExpressionTest extends 
ExpressionRewriteTestHelper {
         computeMultLayerProjectionsMethod.setAccessible(true);
         List<List<NamedExpression>> multiLayers = 
(List<List<NamedExpression>>) computeMultLayerProjectionsMethod
                 .invoke(opt, inputSlots, exprs);
-        System.out.println(multiLayers);
         Assertions.assertEquals(3, multiLayers.size());
         List<NamedExpression> l0 = multiLayers.get(0);
         Assertions.assertEquals(2, l0.size());
         
Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a")));
-        Assertions.assertTrue(l0.get(1) instanceof Alias);
+        Assertions.assertInstanceOf(Alias.class, l0.get(1));
         assertExpression(l0.get(1).child(0), "a+b");
-        Assertions.assertEquals(multiLayers.get(1).size(), 3);
-        Assertions.assertEquals(multiLayers.get(2).size(), 5);
+        Assertions.assertEquals(3, multiLayers.get(1).size());
+        Assertions.assertEquals(5, multiLayers.get(2).size());
         List<NamedExpression> l2 = multiLayers.get(2);
         for (int i = 0; i < 5; i++) {
             Assertions.assertEquals(exprs.get(i).getExprId().asInt(), 
l2.get(i).getExprId().asInt());
diff --git a/regression-test/suites/nereids_rules_p0/cse/cse.groovy b/regression-test/suites/nereids_rules_p0/cse/cse.groovy
new file mode 100644
index 00000000000..07c83dcbacc
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/cse/cse.groovy
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+suite("cse") {
+    // CSE should not extract expressions used for lambdas, such as ArrayItemSlot and ArrayItemReference
+    sql """
+        drop table if exists array_cse;
+    """
+    sql """
+        create table array_cse(c1 int, c2 array<varchar(255)>) PROPERTIES 
("replication_allocation" = "tag.location.default: 1");
+    """
+    sql """
+        insert into array_cse values(1, [1,2,3]);
+    """
+    sql """
+        sync
+    """
+    sql """
+        SELECT array_map(x-> if(left(x, 5) = '12345', x, left(x, 5)), c2) FROM 
array_cse;
+    """
+    sql """
+        SELECT c0, c0 FROM (SELECT ARRAY_MAP(x-> if(left(x, 5), x, left(x, 
5)), `c2`) as `c0` FROM array_cse) t
+    """
+    
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to