morrySnow commented on code in PR #10462:
URL: https://github.com/apache/doris/pull/10462#discussion_r907997591


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java:
##########
@@ -133,6 +133,12 @@ private void assembleAllCombinationPlanTree(Plan root, 
Pattern<Plan, Plan> rootP
             int[] childrenPlanIndex = new int[childrenPlans.size()];
             int offset = 0;
 
+            for (List<Plan> plan : childrenPlans) {

Review Comment:
   Maybe we should do this at L103.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java:
##########
@@ -57,7 +58,7 @@ public RuleSet getRuleSet() {
     public List<Rule<NODE_TYPE>> getValidRules(GroupExpression groupExpression,
             List<Rule<NODE_TYPE>> candidateRules) {
         return candidateRules.stream()
-                .filter(rule -> 
rule.getPattern().matchOperator(groupExpression.getOperator())
+                .filter(rule -> Objects.nonNull(rule) && 
rule.getPattern().matchOperator(groupExpression.getOperator())

Review Comment:
   Why add this? Can the list contain null?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java:
##########
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(logicalJoin()).then(plan -> {
+
+            Expression filterPredicates = plan.operator.getPredicates();
+            Optional<Expression> onPredicates = 
plan.child().operator.getCondition();
+
+            List<Slot> leftInput = plan.child().left().getOutput();
+            List<Slot> rightInput = plan.child().right().getOutput();
+
+            List<Expression> joinConditions = Lists.newArrayList();
+            List<Expression> otherConjuncts = Lists.newArrayList();
+
+            
ExpressionUtils.extractConjunct(onPredicates.get()).forEach(predicate -> {
+                if (Objects.nonNull(getJoinCondition(predicate, leftInput, 
rightInput))) {
+                    joinConditions.add(predicate);
+                } else {
+                    otherConjuncts.add(predicate);
+                }
+            });
+            
otherConjuncts.addAll(ExpressionUtils.extractConjunct(filterPredicates));
+
+            List<Expression> leftPredicates = Lists.newArrayList();
+            List<Expression> rightPredicates = Lists.newArrayList();
+
+            for (Expression conjunct : otherConjuncts) {
+                Set<Slot> slots = SlotExtractor.extractSlot(conjunct);
+
+                if (slots.isEmpty()) {
+                    leftPredicates.add(conjunct);
+                    rightPredicates.add(conjunct);
+                    continue;
+                }
+                if (leftInput.containsAll(slots)) {
+                    leftPredicates.add(conjunct);
+                }
+                if (rightInput.containsAll(slots)) {
+                    rightPredicates.add(conjunct);
+                }
+            }
+            otherConjuncts.removeAll(leftPredicates);
+            otherConjuncts.removeAll(rightPredicates);
+
+            joinConditions.addAll(otherConjuncts);
+
+            return pushDownPredicate(plan.child(), joinConditions, 
leftPredicates, rightPredicates);
+        }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
+    }
+
+    private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, 
GroupPlan> joinPlan,
+            List<Expression> joinConditions, List<Expression> leftPredicates, 
List<Expression> rightPredicates) {
+
+        Expression left = ExpressionUtils.add(leftPredicates);
+        Expression right = ExpressionUtils.add(rightPredicates);
+        ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
+        Plan leftPlan = joinPlan.left();
+        Plan rightPlan = joinPlan.right();
+        if (!left.equals(ExpressionUtils.TRUE_LITERAL)) {
+            leftPlan = plan(new LogicalFilter(exprRewriter.rewrite(left)), 
leftPlan);
+        }
+
+        if (!right.equals(ExpressionUtils.TRUE_LITERAL)) {
+            rightPlan = plan(new LogicalFilter(exprRewriter.rewrite(right)), 
rightPlan);
+        }
+
+        if (!joinConditions.isEmpty()) {
+            return plan(new LogicalJoin(joinPlan.getOperator().getJoinType(),
+                    Optional.of(ExpressionUtils.add(joinConditions))), 
leftPlan, rightPlan);
+        }
+
+        return joinPlan.withChildren(Lists.newArrayList(leftPlan, rightPlan));
+    }
+
+    private Expression getJoinCondition(Expression predicate, List<Slot> 
leftOutput, List<Slot> rightOutput) {
+        if (!(predicate instanceof ComparisonPredicate)) {
+            return null;
+        }
+
+        ComparisonPredicate comparison = (ComparisonPredicate) predicate;
+
+        if (!(comparison.left() instanceof Slot) || !(comparison.right() 
instanceof Slot)) {

Review Comment:
   Maybe we need to support this case as well:
   ```sql
   on t1.k1 + 1 = t2.k2 - 1
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java:
##########
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(logicalJoin()).then(plan -> {
+
+            Expression filterPredicates = plan.operator.getPredicates();
+            Optional<Expression> onPredicates = 
plan.child().operator.getCondition();
+
+            List<Slot> leftInput = plan.child().left().getOutput();
+            List<Slot> rightInput = plan.child().right().getOutput();
+
+            List<Expression> joinConditions = Lists.newArrayList();
+            List<Expression> otherConjuncts = Lists.newArrayList();
+
+            
ExpressionUtils.extractConjunct(onPredicates.get()).forEach(predicate -> {
+                if (Objects.nonNull(getJoinCondition(predicate, leftInput, 
rightInput))) {
+                    joinConditions.add(predicate);
+                } else {
+                    otherConjuncts.add(predicate);
+                }
+            });
+            
otherConjuncts.addAll(ExpressionUtils.extractConjunct(filterPredicates));
+
+            List<Expression> leftPredicates = Lists.newArrayList();
+            List<Expression> rightPredicates = Lists.newArrayList();
+
+            for (Expression conjunct : otherConjuncts) {
+                Set<Slot> slots = SlotExtractor.extractSlot(conjunct);
+
+                if (slots.isEmpty()) {
+                    leftPredicates.add(conjunct);
+                    rightPredicates.add(conjunct);
+                    continue;

Review Comment:
   Why does an empty slot set mean that this conjunct is both a left predicate and a right predicate?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java:
##########
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(logicalJoin()).then(plan -> {
+
+            Expression filterPredicates = plan.operator.getPredicates();
+            Optional<Expression> onPredicates = 
plan.child().operator.getCondition();
+
+            List<Slot> leftInput = plan.child().left().getOutput();
+            List<Slot> rightInput = plan.child().right().getOutput();
+
+            List<Expression> joinConditions = Lists.newArrayList();
+            List<Expression> otherConjuncts = Lists.newArrayList();
+
+            
ExpressionUtils.extractConjunct(onPredicates.get()).forEach(predicate -> {
+                if (Objects.nonNull(getJoinCondition(predicate, leftInput, 
rightInput))) {
+                    joinConditions.add(predicate);
+                } else {
+                    otherConjuncts.add(predicate);
+                }
+            });
+            
otherConjuncts.addAll(ExpressionUtils.extractConjunct(filterPredicates));
+
+            List<Expression> leftPredicates = Lists.newArrayList();
+            List<Expression> rightPredicates = Lists.newArrayList();
+
+            for (Expression conjunct : otherConjuncts) {
+                Set<Slot> slots = SlotExtractor.extractSlot(conjunct);
+
+                if (slots.isEmpty()) {
+                    leftPredicates.add(conjunct);
+                    rightPredicates.add(conjunct);
+                    continue;
+                }
+                if (leftInput.containsAll(slots)) {

Review Comment:
   Expressions in a child's output are not always slots, and a slot in a predicate is a slot reference that refers to an expression in the child's output. So we should use the ExprId in the slot reference to determine which child it comes from.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java:
##########
@@ -37,13 +39,31 @@ public CompoundPredicate(NodeType type, LEFT_CHILD_TYPE 
left, RIGHT_CHILD_TYPE r
         super(type, left, right);
     }
 
+    @Override
+    public String toString() {
+        return sql();
+    }
+
     @Override
     public String sql() {
         String nodeType = getType().toString();
-        return left().sql() + ' ' + nodeType + ' ' + right().sql();
+        return left() + " " + nodeType + " " + right();

Review Comment:
   Why remove `.sql()`?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java:
##########
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.nereids.trees.NodeType;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Expression rewrite helper class.
+ */
+public class ExpressionUtils {
+
+    public static final Literal TRUE_LITERAL = new Literal(true);

Review Comment:
   It may be better to move this to `Literal`.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java:
##########
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(logicalJoin()).then(plan -> {
+
+            Expression filterPredicates = plan.operator.getPredicates();
+            Optional<Expression> onPredicates = 
plan.child().operator.getCondition();

Review Comment:
   Currently we only support equality ON conditions; please add a TODO to support other ON conditions.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java:
##########
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.nereids.trees.NodeType;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Expression rewrite helper class.
+ */
+public class ExpressionUtils {
+
+    public static final Literal TRUE_LITERAL = new Literal(true);
+    public static final Literal FALSE_LITERAL = new Literal(false);
+
+    public static boolean isConstant(Expression expr) {
+        return expr.isConstant();
+    }
+
+    public static List<Expression> extractConjunct(Expression expr) {
+        return extract(NodeType.AND, expr);
+    }
+
+
+    public static List<Expression> extractDisjunct(Expression expr) {
+        return extract(NodeType.OR, expr);
+    }
+
+    public static List<Expression> extract(CompoundPredicate expr) {
+        return extract(expr.getType(), expr);
+    }
+
+    private static List<Expression> extract(NodeType op, Expression expr) {
+        List<Expression> result = Lists.newArrayList();
+        extract(op, expr, result);
+        return result;
+    }
+
+    private static void extract(NodeType op, Expression expr, List<Expression> 
result) {
+        if (expr instanceof CompoundPredicate && expr.getType() == op) {
+            CompoundPredicate predicate = (CompoundPredicate) expr;
+            extract(op, predicate.left(), result);
+            extract(op, predicate.right(), result);
+        } else {
+            result.add(expr);
+        }
+    }
+
+
+    public static Expression add(List<Expression> expressions) {
+        return combine(NodeType.AND, expressions);
+    }
+
+    public static Expression add(Expression... expressions) {
+        return combine(NodeType.AND, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(Expression... expressions) {
+        return combine(NodeType.OR, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(List<Expression> expressions) {
+        return combine(NodeType.OR, expressions);
+    }
+
+    /**
+     * Use AND/OR to combine expressions together.
+     */
+    public static Expression combine(NodeType op, List<Expression> 
expressions) {
+
+        Objects.requireNonNull(expressions, "expressions is null");
+
+        if (expressions.size() == 0) {
+            if (op == NodeType.AND) {
+                return new Literal(true);
+            }
+            if (op == NodeType.OR) {
+                return new Literal(false);
+            }
+        }
+
+        if (expressions.size() == 1) {
+            return expressions.get(0);
+        }
+
+        List<Expression> distinctExpressions = Lists.newArrayList(new 
LinkedHashSet<>(expressions));
+        if (op == NodeType.AND) {
+            if (distinctExpressions.contains(FALSE_LITERAL)) {
+                return FALSE_LITERAL;
+            }
+            distinctExpressions = distinctExpressions.stream().filter(p -> 
!p.equals(TRUE_LITERAL))
+                    .collect(Collectors.toList());
+        }
+
+        if (op == NodeType.OR) {
+            if (distinctExpressions.contains(TRUE_LITERAL)) {
+                return TRUE_LITERAL;
+            }
+            distinctExpressions = distinctExpressions.stream().filter(p -> 
!p.equals(FALSE_LITERAL))
+                    .collect(Collectors.toList());
+        }
+
+        List<List<Expression>> partitions = 
Lists.partition(distinctExpressions, 2);
+        List<Expression> result = new LinkedList<>();
+
+        for (List<Expression> partition : partitions) {
+            if (partition.size() == 2) {
+                result.add(new CompoundPredicate(op, partition.get(0), 
partition.get(1)));
+            }
+            if (partition.size() == 1) {
+                result.add(partition.get(0));
+            }
+        }
+
+        return combine(op, result);

Review Comment:
   How about just combining them one by one?
   ```java
   Expression newExpression = distinctExpressions.stream()
                   .reduce((combined, expression) -> new CompoundPredicate<>(NodeType.AND, combined, expression)).get();
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java:
##########
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(logicalJoin()).then(plan -> {
+
+            Expression filterPredicates = plan.operator.getPredicates();
+            Optional<Expression> onPredicates = 
plan.child().operator.getCondition();
+
+            List<Slot> leftInput = plan.child().left().getOutput();
+            List<Slot> rightInput = plan.child().right().getOutput();
+
+            List<Expression> joinConditions = Lists.newArrayList();
+            List<Expression> otherConjuncts = Lists.newArrayList();
+
+            
ExpressionUtils.extractConjunct(onPredicates.get()).forEach(predicate -> {
+                if (Objects.nonNull(getJoinCondition(predicate, leftInput, 
rightInput))) {
+                    joinConditions.add(predicate);
+                } else {
+                    otherConjuncts.add(predicate);
+                }
+            });
+            
otherConjuncts.addAll(ExpressionUtils.extractConjunct(filterPredicates));
+
+            List<Expression> leftPredicates = Lists.newArrayList();
+            List<Expression> rightPredicates = Lists.newArrayList();
+
+            for (Expression conjunct : otherConjuncts) {
+                Set<Slot> slots = SlotExtractor.extractSlot(conjunct);
+
+                if (slots.isEmpty()) {
+                    leftPredicates.add(conjunct);
+                    rightPredicates.add(conjunct);
+                    continue;
+                }
+                if (leftInput.containsAll(slots)) {
+                    leftPredicates.add(conjunct);
+                }
+                if (rightInput.containsAll(slots)) {
+                    rightPredicates.add(conjunct);
+                }
+            }
+            otherConjuncts.removeAll(leftPredicates);
+            otherConjuncts.removeAll(rightPredicates);
+
+            joinConditions.addAll(otherConjuncts);
+
+            return pushDownPredicate(plan.child(), joinConditions, 
leftPredicates, rightPredicates);
+        }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
+    }
+
+    private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, 
GroupPlan> joinPlan,
+            List<Expression> joinConditions, List<Expression> leftPredicates, 
List<Expression> rightPredicates) {
+
+        Expression left = ExpressionUtils.add(leftPredicates);
+        Expression right = ExpressionUtils.add(rightPredicates);
+        ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
+        Plan leftPlan = joinPlan.left();
+        Plan rightPlan = joinPlan.right();
+        if (!left.equals(ExpressionUtils.TRUE_LITERAL)) {
+            leftPlan = plan(new LogicalFilter(exprRewriter.rewrite(left)), 
leftPlan);
+        }
+
+        if (!right.equals(ExpressionUtils.TRUE_LITERAL)) {
+            rightPlan = plan(new LogicalFilter(exprRewriter.rewrite(right)), 
rightPlan);
+        }
+
+        if (!joinConditions.isEmpty()) {
+            return plan(new LogicalJoin(joinPlan.getOperator().getJoinType(),
+                    Optional.of(ExpressionUtils.add(joinConditions))), 
leftPlan, rightPlan);
+        }
+
+        return joinPlan.withChildren(Lists.newArrayList(leftPlan, rightPlan));
+    }
+
+    private Expression getJoinCondition(Expression predicate, List<Slot> 
leftOutput, List<Slot> rightOutput) {
+        if (!(predicate instanceof ComparisonPredicate)) {
+            return null;
+        }
+
+        ComparisonPredicate comparison = (ComparisonPredicate) predicate;
+
+        if (!(comparison.left() instanceof Slot) || !(comparison.right() 
instanceof Slot)) {
+            return null;
+        }
+
+        Slot left = (Slot) comparison.left();
+        Slot right = (Slot) comparison.right();
+
+        if (!leftOutput.contains(left)) {
+            Slot tmp = left;
+            left = right;
+            right = tmp;
+        }
+
+        if (leftOutput.contains(left) && rightOutput.contains(right)) {
+            return predicate;
+        }

Review Comment:
   ```suggestion
           if ((leftOutput.contains(left) && rightOutput.contains(right))
                || (leftOutput.contains(right) && rightOutput.contains(left))) {
               return predicate;
           }
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java:
##########
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.nereids.trees.NodeType;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Expression rewrite helper class.
+ */
+public class ExpressionUtils {
+
+    public static final Literal TRUE_LITERAL = new Literal(true);
+    public static final Literal FALSE_LITERAL = new Literal(false);
+
+    public static boolean isConstant(Expression expr) {
+        return expr.isConstant();
+    }
+
+    public static List<Expression> extractConjunct(Expression expr) {
+        return extract(NodeType.AND, expr);
+    }
+
+
+    public static List<Expression> extractDisjunct(Expression expr) {
+        return extract(NodeType.OR, expr);
+    }
+
+    public static List<Expression> extract(CompoundPredicate expr) {
+        return extract(expr.getType(), expr);
+    }
+
+    private static List<Expression> extract(NodeType op, Expression expr) {
+        List<Expression> result = Lists.newArrayList();
+        extract(op, expr, result);
+        return result;
+    }
+
+    private static void extract(NodeType op, Expression expr, List<Expression> 
result) {
+        if (expr instanceof CompoundPredicate && expr.getType() == op) {
+            CompoundPredicate predicate = (CompoundPredicate) expr;
+            extract(op, predicate.left(), result);
+            extract(op, predicate.right(), result);
+        } else {
+            result.add(expr);
+        }
+    }
+
+
+    public static Expression add(List<Expression> expressions) {
+        return combine(NodeType.AND, expressions);
+    }
+
+    public static Expression add(Expression... expressions) {
+        return combine(NodeType.AND, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(Expression... expressions) {
+        return combine(NodeType.OR, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(List<Expression> expressions) {
+        return combine(NodeType.OR, expressions);
+    }
+
+    /**
+     * Use AND/OR to combine expressions together.
+     */
+    public static Expression combine(NodeType op, List<Expression> 
expressions) {
+
+        Objects.requireNonNull(expressions, "expressions is null");
+
+        if (expressions.size() == 0) {
+            if (op == NodeType.AND) {
+                return new Literal(true);
+            }
+            if (op == NodeType.OR) {
+                return new Literal(false);
+            }
+        }
+
+        if (expressions.size() == 1) {
+            return expressions.get(0);
+        }
+
+        List<Expression> distinctExpressions = Lists.newArrayList(new 
LinkedHashSet<>(expressions));

Review Comment:
   What is the `LinkedHashSet` used for?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to