This is an automated email from the ASF dual-hosted git repository. xiejiann pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 015f051f73a [opt](Nereids) Optimize findValidItems method to handle circular dependencies (#36839) 015f051f73a is described below commit 015f051f73a8fe6d0e50dc643b2fcd114838b465 Author: 谢健 <jianx...@gmail.com> AuthorDate: Wed Jun 26 16:51:39 2024 +0800 [opt](Nereids) Optimize findValidItems method to handle circular dependencies (#36839) ## Proposed changes findValidItems now handles circular functional dependencies without eliminating the group-by keys that the output requires. The DFS first visits parent slot sets that are fully contained in requireOutputs, and only then the remaining parents. As a result, when a cycle is broken, the dropped edge is the one that closes the cycle back into a required slot set, while the edges that start from required output slots are kept. To support this, eliminateDeps (and its caller EliminateGroupByKey) now pass requireOutputs through to findValidItems. --- .../apache/doris/nereids/properties/FuncDeps.java | 25 +++++++++++++++++----- .../nereids/rules/rewrite/EliminateGroupByKey.java | 2 +- .../doris/nereids/properties/FuncDepsTest.java | 13 +++++------ .../rules/rewrite/EliminateGroupByKeyTest.java | 21 ++++++++++++------ 4 files changed, 43 insertions(+), 18 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java index 6c1b302d7dc..c17fd2eee57 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java @@ -27,6 +27,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; /** * Function dependence items. @@ -96,11 +97,25 @@ public class FuncDeps { } } - // find item that not in a circle - private Set<FuncDepsItem> findValidItems() { + // Find items that are not part of a circular dependency. + // To keep the slots in requireOutputs, we need to always keep the edges that start with output slots.
+ // Note: We reduce the last edge in a circular dependency, + // so we need to traverse from parents that contain the required output slots. + private Set<FuncDepsItem> findValidItems(Set<Slot> requireOutputs) { Set<FuncDepsItem> circleItem = new HashSet<>(); Set<Set<Slot>> visited = new HashSet<>(); - for (Set<Slot> parent : edges.keySet()) { + Set<Set<Slot>> parentInOutput = edges.keySet().stream() + .filter(requireOutputs::containsAll) + .collect(Collectors.toSet()); + for (Set<Slot> parent : parentInOutput) { + if (!visited.contains(parent)) { + dfs(parent, visited, circleItem); + } + } + Set<Set<Slot>> otherParent = edges.keySet().stream() + .filter(parent -> !parentInOutput.contains(parent)) + .collect(Collectors.toSet()); + for (Set<Slot> parent : otherParent) { if (!visited.contains(parent)) { dfs(parent, visited, circleItem); } @@ -126,10 +141,10 @@ public class FuncDeps { * @param slots the initial set of slot sets to be reduced * @return the minimal set of slot sets after applying all possible reductions */ - public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots) { + public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots, Set<Slot> requireOutputs) { Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots); Set<Set<Slot>> eliminatedSlots = new HashSet<>(); - Set<FuncDepsItem> validItems = findValidItems(); + Set<FuncDepsItem> validItems = findValidItems(requireOutputs); for (FuncDepsItem funcDepsItem : validItems) { if (minSlotSet.contains(funcDepsItem.dependencies) && minSlotSet.contains(funcDepsItem.determinants)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java index 9e205f85809..fbe0988daff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java @@ -91,7 +91,7 @@ public 
class EliminateGroupByKey implements RewriteRuleFactory { return null; } - Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values())); + Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()), requireOutput); Set<Expression> removeExpression = new HashSet<>(); for (Entry<Expression, Set<Slot>> entry : groupBySlots.entrySet()) { if (!minGroupBySlots.contains(entry.getValue()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java index 64df33acd60..6b17305ed7a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.types.IntegerType; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -43,7 +44,7 @@ class FuncDepsTest { Set<Set<Slot>> slotSet = Sets.newHashSet(set1, set2, set3, set4); FuncDeps funcDeps = new FuncDeps(); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> expected = new HashSet<>(); expected.add(set1); expected.add(set3); @@ -58,7 +59,7 @@ class FuncDepsTest { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> 
expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -71,7 +72,7 @@ class FuncDepsTest { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s4)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -83,7 +84,7 @@ class FuncDepsTest { FuncDeps funcDeps = new FuncDeps(); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s1)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> expected = new HashSet<>(); expected.add(set1); expected.add(set3); @@ -99,7 +100,7 @@ class FuncDepsTest { funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); funcDeps.addFuncItems(Sets.newHashSet(s4), Sets.newHashSet(s1)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -112,7 +113,7 @@ class FuncDepsTest { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set<Set<Slot>> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java index 203e902b3eb..5a9e15cf477 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java @@ -66,7 +66,7 @@ class EliminateGroupByKeyTest extends TestWithFeService implements MemoPatternMa funcDeps.addFuncItems(set1, set2); funcDeps.addFuncItems(set2, set3); funcDeps.addFuncItems(set3, set4); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -78,7 +78,7 @@ class EliminateGroupByKeyTest extends TestWithFeService implements MemoPatternMa funcDeps.addFuncItems(set2, set3); funcDeps.addFuncItems(set3, set4); funcDeps.addFuncItems(set4, set1); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -89,7 +89,7 @@ class EliminateGroupByKeyTest extends TestWithFeService implements MemoPatternMa funcDeps.addFuncItems(set1, set2); funcDeps.addFuncItems(set1, set3); funcDeps.addFuncItems(set1, set4); - Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -163,11 +163,20 @@ class EliminateGroupByKeyTest extends 
TestWithFeService implements MemoPatternMa @Test void testEliminateByEqual() { PlanChecker.from(connectContext) - .analyze("select count(t1.name) from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name") + .analyze("select t1.name from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name") .rewrite() .printlnTree() .matches(logicalAggregate().when(agg -> - agg.getGroupByExpressions().size() == 1 && agg.getGroupByExpressions().get(0).toSql().equals("name"))); - } + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("name"))); + PlanChecker.from(connectContext) + .analyze("select t2.name from t1 as t1 join t1 as t2 " + + "on t1.name = t2.name group by t1.name, t2.name") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("name"))); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org