This is an automated email from the ASF dual-hosted git repository.

xiejiann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 015f051f73a [opt](Nereids) Optimize findValidItems method to handle circular dependencies (#36839)
015f051f73a is described below

commit 015f051f73a8fe6d0e50dc643b2fcd114838b465
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Wed Jun 26 16:51:39 2024 +0800

    [opt](Nereids) Optimize findValidItems method to handle circular dependencies (#36839)
    
    ## Proposed changes
    
    These optimizations allow the findValidItems method to correctly handle
    circular dependencies while maintaining the required output slots. The
    code is now more efficient and ensures that the necessary edges and
    items are preserved during the traversal process.
---
 .../apache/doris/nereids/properties/FuncDeps.java  | 25 +++++++++++++++++-----
 .../nereids/rules/rewrite/EliminateGroupByKey.java |  2 +-
 .../doris/nereids/properties/FuncDepsTest.java     | 13 +++++------
 .../rules/rewrite/EliminateGroupByKeyTest.java     | 21 ++++++++++++------
 4 files changed, 43 insertions(+), 18 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
index 6c1b302d7dc..c17fd2eee57 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
@@ -27,6 +27,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Function dependence items.
@@ -96,11 +97,25 @@ public class FuncDeps {
         }
     }
 
-    // find item that not in a circle
-    private Set<FuncDepsItem> findValidItems() {
+    // Find items that are not part of a circular dependency.
+    // To keep the slots in requireOutputs, we need to always keep the edges that start with output slots.
+    // Note: We reduce the last edge in a circular dependency,
+    // so we need to traverse from parents that contain the required output slots.
+    private Set<FuncDepsItem> findValidItems(Set<Slot> requireOutputs) {
         Set<FuncDepsItem> circleItem = new HashSet<>();
         Set<Set<Slot>> visited = new HashSet<>();
-        for (Set<Slot> parent : edges.keySet()) {
+        Set<Set<Slot>> parentInOutput = edges.keySet().stream()
+                .filter(requireOutputs::containsAll)
+                .collect(Collectors.toSet());
+        for (Set<Slot> parent : parentInOutput) {
+            if (!visited.contains(parent)) {
+                dfs(parent, visited, circleItem);
+            }
+        }
+        Set<Set<Slot>> otherParent = edges.keySet().stream()
+                .filter(parent -> !parentInOutput.contains(parent))
+                .collect(Collectors.toSet());
+        for (Set<Slot> parent : otherParent) {
             if (!visited.contains(parent)) {
                 dfs(parent, visited, circleItem);
             }
@@ -126,10 +141,10 @@ public class FuncDeps {
      * @param slots the initial set of slot sets to be reduced
      * @return the minimal set of slot sets after applying all possible reductions
      */
-    public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots) {
+    public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots, Set<Slot> 
requireOutputs) {
         Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots);
         Set<Set<Slot>> eliminatedSlots = new HashSet<>();
-        Set<FuncDepsItem> validItems = findValidItems();
+        Set<FuncDepsItem> validItems = findValidItems(requireOutputs);
         for (FuncDepsItem funcDepsItem : validItems) {
             if (minSlotSet.contains(funcDepsItem.dependencies)
                     && minSlotSet.contains(funcDepsItem.determinants)) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
index 9e205f85809..fbe0988daff 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
@@ -91,7 +91,7 @@ public class EliminateGroupByKey implements 
RewriteRuleFactory {
             return null;
         }
 
-        Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new 
HashSet<>(groupBySlots.values()));
+        Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new 
HashSet<>(groupBySlots.values()), requireOutput);
         Set<Expression> removeExpression = new HashSet<>();
         for (Entry<Expression, Set<Slot>> entry : groupBySlots.entrySet()) {
             if (!minGroupBySlots.contains(entry.getValue())
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
index 64df33acd60..6b17305ed7a 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.types.IntegerType;
 
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -43,7 +44,7 @@ class FuncDepsTest {
         Set<Set<Slot>> slotSet = Sets.newHashSet(set1, set2, set3, set4);
         FuncDeps funcDeps = new FuncDeps();
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         expected.add(set3);
@@ -58,7 +59,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
         funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3));
         funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         Assertions.assertEquals(expected, slots);
@@ -71,7 +72,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3));
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s4));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         Assertions.assertEquals(expected, slots);
@@ -83,7 +84,7 @@ class FuncDepsTest {
         FuncDeps funcDeps = new FuncDeps();
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
         funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s1));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         expected.add(set3);
@@ -99,7 +100,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3));
         funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
         funcDeps.addFuncItems(Sets.newHashSet(s4), Sets.newHashSet(s1));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         Assertions.assertEquals(expected, slots);
@@ -112,7 +113,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
         funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3));
         funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, 
ImmutableSet.of());
         Set<Set<Slot>> expected = new HashSet<>();
         expected.add(set1);
         Assertions.assertEquals(expected, slots);
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
index 203e902b3eb..5a9e15cf477 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
@@ -66,7 +66,7 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
         funcDeps.addFuncItems(set1, set2);
         funcDeps.addFuncItems(set2, set3);
         funcDeps.addFuncItems(set3, set4);
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4), ImmutableSet.of());
         Assertions.assertEquals(1, slots.size());
         Assertions.assertEquals(set1, slots.iterator().next());
     }
@@ -78,7 +78,7 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
         funcDeps.addFuncItems(set2, set3);
         funcDeps.addFuncItems(set3, set4);
         funcDeps.addFuncItems(set4, set1);
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4), ImmutableSet.of());
         Assertions.assertEquals(1, slots.size());
         Assertions.assertEquals(set1, slots.iterator().next());
     }
@@ -89,7 +89,7 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
         funcDeps.addFuncItems(set1, set2);
         funcDeps.addFuncItems(set1, set3);
         funcDeps.addFuncItems(set1, set4);
-        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4), ImmutableSet.of());
         Assertions.assertEquals(1, slots.size());
         Assertions.assertEquals(set1, slots.iterator().next());
     }
@@ -163,11 +163,20 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
     @Test
     void testEliminateByEqual() {
         PlanChecker.from(connectContext)
-                .analyze("select count(t1.name) from t1 as t1 join t1 as t2 on 
t1.name = t2.name group by t1.name, t2.name")
+                .analyze("select t1.name from t1 as t1 join t1 as t2 on 
t1.name = t2.name group by t1.name, t2.name")
                 .rewrite()
                 .printlnTree()
                 .matches(logicalAggregate().when(agg ->
-                        agg.getGroupByExpressions().size() == 1 && 
agg.getGroupByExpressions().get(0).toSql().equals("name")));
-    }
+                        agg.getGroupByExpressions().size() == 1
+                                && 
agg.getGroupByExpressions().get(0).toSql().equals("name")));
 
+        PlanChecker.from(connectContext)
+                .analyze("select t2.name from t1 as t1 join t1 as t2 "
+                        + "on t1.name = t2.name group by t1.name, t2.name")
+                .rewrite()
+                .printlnTree()
+                .matches(logicalAggregate().when(agg ->
+                        agg.getGroupByExpressions().size() == 1
+                                && 
agg.getGroupByExpressions().get(0).toSql().equals("name")));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to