This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push: new e4d0c672a30 branch-2.1: [fix](Nereids) cse extract wrong expression from lambda expressions (#49166) (#49942) e4d0c672a30 is described below commit e4d0c672a3083006ab4ba090258b862240c8265f Author: 924060929 <lanhuaj...@selectdb.com> AuthorDate: Thu Apr 10 23:22:55 2025 +0800 branch-2.1: [fix](Nereids) cse extract wrong expression from lambda expressions (#49166) (#49942) cherry pick from #49166 Co-authored-by: morrySnow <zhangwen...@selectdb.com> --- .../post/CommonSubExpressionCollector.java | 24 ++++++++--- .../processor/post/CommonSubExpressionOpt.java | 2 +- .../postprocess/CommonSubExpressionTest.java | 46 ++++++++++++++++------ .../suites/nereids_rules_p0/cse/cse.groovy | 39 ++++++++++++++++++ 4 files changed, 92 insertions(+), 19 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java index 877e411a539..520902c0439 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java @@ -17,7 +17,10 @@ package org.apache.doris.nereids.processor.post; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import java.util.HashMap; @@ -28,29 +31,38 @@ import java.util.Set; /** * collect common expr */ -public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> { +public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, 
Boolean> { public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>(); private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>(); + public int collect(Expression expr) { + return expr.accept(this, expr instanceof Lambda); + } + @Override - public Integer visit(Expression expr, Void context) { + public Integer visit(Expression expr, Boolean inLambda) { if (expr.children().isEmpty()) { return 0; } return collectCommonExpressionByDepth( expr.children() .stream() - .map(child -> child.accept(this, context)) + .map(child -> child.accept(this, inLambda == null || inLambda || child instanceof Lambda)) .reduce(Math::max) .map(m -> m + 1) .orElse(1), - expr + expr, + inLambda == null || inLambda ); } - private int collectCommonExpressionByDepth(int depth, Expression expr) { + private int collectCommonExpressionByDepth(int depth, Expression expr, boolean inLambda) { Set<Expression> expressions = getExpressionsFromDepthMap(depth, expressionsByDepth); - if (expressions.contains(expr)) { + // ArrayItemSlot and ArrayItemReference could not be common expressions + // TODO: could not extract common expression when expression contains same lambda expression + // because ArrayItemSlot in Lambda are not same. 
+ if (expressions.contains(expr) + && !(inLambda && expr.containsType(ArrayItemSlot.class, ArrayItemReference.class))) { Set<Expression> commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth); commonExpression.add(expr); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java index fca84167994..82239f2c2c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java @@ -65,7 +65,7 @@ public class CommonSubExpressionOpt extends PlanPostProcessor { List<List<NamedExpression>> multiLayers = Lists.newArrayList(); CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); for (Expression expr : projects) { - expr.accept(collector, null); + collector.collect(expr); } Map<Expression, Alias> aliasMap = new HashMap<>(); if (!collector.commonExprByDepth.isEmpty()) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java index 56b67e087d5..c9a20f65e69 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -20,14 +20,24 @@ package org.apache.doris.nereids.postprocess; import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.And; +import 
org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.IntegerType; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -37,27 +47,40 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { @Test public void testExtractCommonExpr() { List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); - CommonSubExpressionCollector collector = - new CommonSubExpressionCollector(); + CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); exprs.forEach(expr -> collector.visit(expr, null)); - System.out.println(collector.commonExprByDepth); Assertions.assertEquals(2, collector.commonExprByDepth.size()); - List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1)) - .stream().collect(Collectors.toList()); - List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2)) - .stream().collect(Collectors.toList()); + List<Expression> l1 = new ArrayList<>(collector.commonExprByDepth.get(1)); + List<Expression> l2 = new 
ArrayList<>(collector.commonExprByDepth.get(2)); Assertions.assertEquals(1, l1.size()); assertExpression(l1.get(0), "a+b"); Assertions.assertEquals(1, l2.size()); assertExpression(l2.get(0), "a+b+1"); } + @Test + void testLambdaExpression() { + ArrayItemReference ref = new ArrayItemReference("x", new SlotReference(new ExprId(1), "y", + ArrayType.of(IntegerType.INSTANCE), true, ImmutableList.of())); + Expression add = new Add(ref.toSlot(), Literal.of(1)); + Expression and = new And(add, add); + ArrayMap arrayMap = new ArrayMap(new Lambda(ImmutableList.of("x"), and, ImmutableList.of(ref))); + List<NamedExpression> exprs = Lists.newArrayList( + new Alias(new ExprId(10000), arrayMap, "c1"), + new Alias(new ExprId(10001), arrayMap, "c2") + ); + CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); + exprs.forEach(expr -> collector.visit(expr, false)); + Assertions.assertEquals(1, collector.commonExprByDepth.size()); + Assertions.assertEquals(1, collector.commonExprByDepth.get(4).size()); + Assertions.assertEquals(arrayMap, collector.commonExprByDepth.get(4).iterator().next()); + } + @Test public void testMultiLayers() throws Exception { List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); @@ -68,15 +91,14 @@ public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { computeMultLayerProjectionsMethod.setAccessible(true); List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultLayerProjectionsMethod .invoke(opt, inputSlots, exprs); - System.out.println(multiLayers); Assertions.assertEquals(3, multiLayers.size()); List<NamedExpression> l0 = multiLayers.get(0); Assertions.assertEquals(2, l0.size()); Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a"))); - Assertions.assertTrue(l0.get(1) instanceof Alias); + Assertions.assertInstanceOf(Alias.class, l0.get(1)); assertExpression(l0.get(1).child(0), "a+b"); - Assertions.assertEquals(multiLayers.get(1).size(), 3); 
- Assertions.assertEquals(multiLayers.get(2).size(), 5); + Assertions.assertEquals(3, multiLayers.get(1).size()); + Assertions.assertEquals(5, multiLayers.get(2).size()); List<NamedExpression> l2 = multiLayers.get(2); for (int i = 0; i < 5; i++) { Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt()); diff --git a/regression-test/suites/nereids_rules_p0/cse/cse.groovy b/regression-test/suites/nereids_rules_p0/cse/cse.groovy new file mode 100644 index 00000000000..07c83dcbacc --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/cse/cse.groovy @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+suite("cse") { + // cse should not extract expression use for lambda, such as ArrayItemSlot and ArrayItemReference + sql """ + drop table if exists array_cse; + """ + sql """ + create table array_cse(c1 int, c2 array<varchar(255)>) PROPERTIES ("replication_allocation" = "tag.location.default: 1"); + """ + sql """ + insert into array_cse values(1, [1,2,3]); + """ + sql """ + sync + """ + sql """ + SELECT array_map(x-> if(left(x, 5) = '12345', x, left(x, 5)), c2) FROM array_cse; + """ + sql """ + SELECT c0, c0 FROM (SELECT ARRAY_MAP(x-> if(left(x, 5), x, left(x, 5)), `c2`) as `c0` FROM array_cse) t + """ + +} + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org