This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 73f6d9532dce59c4610c297eaa4812e8c2c1a663
Author: minghong <engle...@gmail.com>
AuthorDate: Thu Mar 7 14:29:25 2024 +0800

    [feat](nereids) support null safe eq runtime filter (FE part) (#31655)
    
    The BE part has been merged in #31754.
---
 .../processor/post/RuntimeFilterGenerator.java     | 34 +++++------
 .../rules/expression/ExpressionOptimization.java   |  4 +-
 .../expression/rules/NullSafeEqualToEqual.java     | 62 +++++++++++++++++++
 .../rules/rewrite/FindHashConditionForJoin.java    |  3 +-
 .../trees/plans/physical/AbstractPhysicalPlan.java | 10 ++++
 .../org/apache/doris/planner/HashJoinNode.java     |  6 --
 .../org/apache/doris/planner/RuntimeFilter.java    |  9 +++
 .../main/java/org/apache/doris/qe/Coordinator.java |  4 +-
 .../expression/rules/NullSafeEqualToEqualTest.java | 69 ++++++++++++++++++++++
 .../rules/rewrite/EliminateJoinByUniqueTest.java   | 14 -----
 10 files changed, 174 insertions(+), 41 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
index f7a7c166bd7..8da3ede4200 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.stats.ExpressionEstimation;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.CTEId;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -268,25 +269,22 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
         List<Expression> hashJoinConjuncts = 
join.getHashJoinConjuncts().stream().collect(Collectors.toList());
         boolean buildSideContainsConsumer = 
hasCTEConsumerDescendant((PhysicalPlan) join.right());
         for (int i = 0; i < hashJoinConjuncts.size(); i++) {
-            // BE do not support RF generated from NullSafeEqual, skip them
-            if (hashJoinConjuncts.get(i) instanceof EqualTo) {
-                EqualTo equalTo = ((EqualTo) 
JoinUtils.swapEqualToForChildrenOrder(
-                        (EqualTo) hashJoinConjuncts.get(i), 
join.left().getOutputSet()));
-                for (TRuntimeFilterType type : legalTypes) {
-                    //bitmap rf is generated by nested loop join.
-                    if (type == TRuntimeFilterType.BITMAP) {
-                        continue;
-                    }
-                    long buildSideNdv = getBuildSideNdv(join, equalTo);
-                    Pair<PhysicalRelation, Slot> pair = 
ctx.getAliasTransferMap().get(equalTo.right());
-                    // CteConsumer is not allowed to generate RF in order to 
avoid RF cycle.
-                    if ((pair == null && buildSideContainsConsumer)
-                            || (pair != null && pair.first instanceof 
PhysicalCTEConsumer)) {
-                        continue;
-                    }
-                    join.pushDownRuntimeFilter(context, generator, join, 
equalTo.right(),
-                            equalTo.left(), type, buildSideNdv, i);
+            EqualPredicate equalTo = JoinUtils.swapEqualToForChildrenOrder(
+                    (EqualPredicate) hashJoinConjuncts.get(i), 
join.left().getOutputSet());
+            for (TRuntimeFilterType type : legalTypes) {
+                //bitmap rf is generated by nested loop join.
+                if (type == TRuntimeFilterType.BITMAP) {
+                    continue;
+                }
+                long buildSideNdv = getBuildSideNdv(join, equalTo);
+                Pair<PhysicalRelation, Slot> pair = 
ctx.getAliasTransferMap().get(equalTo.right());
+                // CteConsumer is not allowed to generate RF in order to avoid 
RF cycle.
+                if ((pair == null && buildSideContainsConsumer)
+                        || (pair != null && pair.first instanceof 
PhysicalCTEConsumer)) {
+                    continue;
                 }
+                join.pushDownRuntimeFilter(context, generator, join, 
equalTo.right(),
+                        equalTo.left(), type, buildSideNdv, i);
             }
         }
         return join;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
index e7b3a308f0f..fdf9820c582 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
@@ -22,6 +22,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf;
 import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
 import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
+import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
 import org.apache.doris.nereids.rules.expression.rules.OrToIn;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
@@ -48,7 +49,8 @@ public class ExpressionOptimization extends ExpressionRewrite 
{
             OrToIn.INSTANCE,
             ArrayContainToArrayOverlap.INSTANCE,
             CaseWhenToIf.INSTANCE,
-            TopnToMax.INSTANCE
+            TopnToMax.INSTANCE,
+            NullSafeEqualToEqual.INSTANCE
     );
     private static final ExpressionRuleExecutor EXECUTOR = new 
ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java
new file mode 100644
index 00000000000..c215e65f722
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+
+/**
+ * convert "<=>" to "=" only if both sides are not nullable (a NULL operand would turn FALSE into NULL)
+ * convert "A <=> null" to "A is null", or to FALSE when A is not nullable
+ */
+public class NullSafeEqualToEqual extends 
DefaultExpressionRewriter<ExpressionRewriteContext> implements
+        ExpressionRewriteRule<ExpressionRewriteContext> {
+    public static final NullSafeEqualToEqual INSTANCE = new 
NullSafeEqualToEqual();
+
+    @Override
+    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
+        return expr.accept(this, ctx);
+    }
+
+    @Override
+    public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, 
ExpressionRewriteContext ctx) {
+        if (nullSafeEqual.left() instanceof NullLiteral) {
+            if (nullSafeEqual.right().nullable()) {
+                return new IsNull(nullSafeEqual.right());
+            } else {
+                return BooleanLiteral.FALSE;
+            }
+        } else if (nullSafeEqual.right() instanceof NullLiteral) {
+            if (nullSafeEqual.left().nullable()) {
+                return new IsNull(nullSafeEqual.left());
+            } else {
+                return BooleanLiteral.FALSE;
+            }
+        } else if (!nullSafeEqual.left().nullable() && 
!nullSafeEqual.right().nullable()) {
+            return new EqualTo(nullSafeEqual.left(), nullSafeEqual.right());
+        }
+        return nullSafeEqual;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoin.java
index 2826d709065..86bcdf4487b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoin.java
@@ -63,7 +63,8 @@ public class FindHashConditionForJoin extends 
OneRewriteRuleFactory {
             }
 
             List<Expression> combinedHashJoinConjuncts = Streams
-                    .concat(join.getHashJoinConjuncts().stream(), 
extractedHashJoinConjuncts.stream())
+                    .concat(join.getHashJoinConjuncts().stream(),
+                            extractedHashJoinConjuncts.stream())
                     .distinct()
                     .collect(ImmutableList.toImmutableList());
             JoinType joinType = join.getJoinType();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalPlan.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalPlan.java
index 1e9135d600e..9f71bda0b40 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalPlan.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalPlan.java
@@ -24,7 +24,9 @@ import 
org.apache.doris.nereids.processor.post.RuntimeFilterContext;
 import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.AbstractPlan;
 import org.apache.doris.nereids.trees.plans.Explainable;
@@ -131,6 +133,14 @@ public abstract class AbstractPhysicalPlan extends 
AbstractPlan implements Physi
                 
ctx.setTargetsOnScanNode(ctx.getAliasTransferPair(probeSlot).first, scanSlot);
             }
         } else {
+            // null safe equal runtime filter only support bloom filter
+            EqualPredicate eq = (EqualPredicate) 
builderNode.getHashJoinConjuncts().get(exprOrder);
+            if (eq instanceof NullSafeEqual && type == 
TRuntimeFilterType.IN_OR_BLOOM) {
+                type = TRuntimeFilterType.BLOOM;
+            }
+            if (eq instanceof NullSafeEqual && type != 
TRuntimeFilterType.BLOOM) {
+                return false;
+            }
             filter = new RuntimeFilter(generator.getNextId(),
                     src, ImmutableList.of(scanSlot), 
ImmutableList.of(probeExpr),
                     type, exprOrder, builderNode, buildSideNdv,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java
index 1413d42ceee..d8cc4a77a0a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java
@@ -199,12 +199,6 @@ public class HashJoinNode extends JoinNodeBase {
         for (Expr eqJoinPredicate : eqJoinConjuncts) {
             Preconditions.checkArgument(eqJoinPredicate instanceof 
BinaryPredicate);
             BinaryPredicate eqJoin = (BinaryPredicate) eqJoinPredicate;
-            if (eqJoin.getOp().equals(BinaryPredicate.Operator.EQ_FOR_NULL)) {
-                Preconditions.checkArgument(eqJoin.getChildren().size() == 2);
-                if (!eqJoin.getChild(0).isNullable() || 
!eqJoin.getChild(1).isNullable()) {
-                    eqJoin.setOp(BinaryPredicate.Operator.EQ);
-                }
-            }
             this.eqJoinConjuncts.add(eqJoin);
         }
         this.distrMode = DistributionMode.NONE;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/RuntimeFilter.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/RuntimeFilter.java
index ade087e2de0..00117beae72 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/RuntimeFilter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/RuntimeFilter.java
@@ -230,6 +230,15 @@ public final class RuntimeFilter {
         }
         tFilter.setOptRemoteRf(hasRemoteTargets);
         
tFilter.setBloomFilterSizeCalculatedByNdv(bloomFilterSizeCalculatedByNdv);
+        if (builderNode instanceof HashJoinNode) {
+            HashJoinNode join = (HashJoinNode) builderNode;
+            BinaryPredicate eq = join.getEqJoinConjuncts().get(exprOrder);
+            if (eq.getOp().equals(BinaryPredicate.Operator.EQ_FOR_NULL)) {
+                tFilter.setNullAware(true);
+            } else {
+                tFilter.setNullAware(false);
+            }
+        }
         return tFilter;
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
index 64f32ed2e24..872a149b176 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
@@ -94,6 +94,7 @@ import org.apache.doris.thrift.TPipelineFragmentParams;
 import org.apache.doris.thrift.TPipelineFragmentParamsList;
 import org.apache.doris.thrift.TPipelineInstanceParams;
 import org.apache.doris.thrift.TPipelineWorkloadGroup;
+import org.apache.doris.thrift.TPlanFragment;
 import org.apache.doris.thrift.TPlanFragmentDestination;
 import org.apache.doris.thrift.TPlanFragmentExecParams;
 import org.apache.doris.thrift.TQueryGlobals;
@@ -3703,6 +3704,7 @@ public class Coordinator implements CoordInterface {
 
             Map<TNetworkAddress, TPipelineFragmentParams> res = new HashMap();
             Map<TNetworkAddress, Integer> instanceIdx = new HashMap();
+            TPlanFragment fragmentThrift = fragment.toThrift();
             for (int i = 0; i < instanceExecParams.size(); ++i) {
                 final FInstanceExecParam instanceExecParam = 
instanceExecParams.get(i);
                 Map<Integer, List<TScanRangeParams>> scanRanges = 
instanceExecParam.perNodeScanRanges;
@@ -3728,7 +3730,7 @@ public class Coordinator implements CoordInterface {
                     params.query_options.setMemLimit(memLimit);
                     params.setSendQueryStatisticsWithEveryBatch(
                             
fragment.isTransferQueryStatisticsWithEveryBatch());
-                    params.setFragment(fragment.toThrift());
+                    params.setFragment(fragmentThrift);
                     params.setLocalParams(Lists.newArrayList());
                     if (tWorkloadGroups != null) {
                         params.setWorkloadGroups(tWorkloadGroups);
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java
new file mode 100644
index 00000000000..3a2b4eb2a6e
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.types.StringType;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
+
+    // "A<=> Null" to "A is null"
+    @Test
+    void testNullSafeEqualToIsNull() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
+        SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
+        assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new 
IsNull(slot));
+    }
+
+    // "A<=> Null" to "False", when A is not nullable
+    @Test
+    void testNullSafeEqualToFalse() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
+        SlotReference slot = new SlotReference("a", StringType.INSTANCE, 
false);
+        assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), 
BooleanLiteral.FALSE);
+    }
+
+    // "A <=> 'abc'" to "A = 'abc'", when neither side is nullable
+    @Test
+    void testNullSafeEqualToEqual() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
+        SlotReference slot = new SlotReference("a", StringType.INSTANCE, false);
+        StringLiteral str = new StringLiteral("abc");
+        assertRewrite(new NullSafeEqual(slot, str), new EqualTo(slot, str));
+    }
+
+    // "A <=> B" not changed when both sides are nullable
+    @Test
+    void testNullSafeEqualNotChanged() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
+        SlotReference a = new SlotReference("a", StringType.INSTANCE, true);
+        SlotReference b = new SlotReference("b", StringType.INSTANCE, true);
+        assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b));
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUniqueTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUniqueTest.java
index c3b9c9de005..778b65f992c 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUniqueTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUniqueTest.java
@@ -56,13 +56,6 @@ class EliminateJoinByUniqueTest extends TestWithFeService 
implements MemoPattern
                 .nonMatch(logicalJoin())
                 .printlnTree();
 
-        sql = "select t1.id1 from t1 left outer join t2 on t1.id1 <=> t2.id2";
-        PlanChecker.from(connectContext)
-                .analyze(sql)
-                .rewrite()
-                .matches(logicalJoin())
-                .printlnTree();
-
         sql = "select t2.id2 from t1 left outer join t2 on t1.id1 = t2.id2";
         PlanChecker.from(connectContext)
                 .analyze(sql)
@@ -80,13 +73,6 @@ class EliminateJoinByUniqueTest extends TestWithFeService 
implements MemoPattern
                 .nonMatch(logicalJoin())
                 .printlnTree();
 
-        sql = "select t1.id1 from t1 left outer join t2 on t1.id_null <=> 
t2.id2";
-        PlanChecker.from(connectContext)
-                .analyze(sql)
-                .rewrite()
-                .matches(logicalJoin())
-                .printlnTree();
-
         sql = "select t2.id2 from t1 left outer join t2 on t1.id_null = 
t2.id2";
         PlanChecker.from(connectContext)
                 .analyze(sql)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to