lsyldliu commented on code in PR #22966:

@@ -0,0 +1,678 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.optimize.program;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
+import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
+import org.apache.flink.table.planner.plan.utils.DefaultRelShuttle;
+import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil;
+import org.apache.flink.table.planner.plan.utils.JoinUtil;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Calc;
+import org.apache.calcite.rel.core.Exchange;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinInfo;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.BiFunction;
+import static 
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+ * Planner program that tries to inject runtime filters for suitable joins to improve join
+ * performance.
+ *
+ * <p>We build the runtime filter in a two-phase manner: First, each subtask on the build side
+ * builds a local filter based on its local data, and sends the built filter to a global aggregation
+ * node. Then the global aggregation node aggregates the received filters into a global filter, and
+ * sends the global filter to all probe side subtasks. Therefore, we will add {@link
+ * BatchPhysicalLocalRuntimeFilterBuilder}, {@link BatchPhysicalGlobalRuntimeFilterBuilder} and
+ * {@link BatchPhysicalRuntimeFilter} into the physical plan.
+ *
+ * <p>For example, for the following query:
+ *
+ * <pre>{@code SELECT * FROM fact, dim WHERE x = a AND z = 2}</pre>
+ *
+ * <p>The original physical plan:
+ *
+ * <pre>{@code
+ * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z])
+ * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], 
+ *    :- Exchange(distribution=[hash[a]])
+ *    :  +- TableSourceScan(table=[[fact]], fields=[a, b, c])
+ *    +- Exchange(distribution=[hash[x]])
+ *       +- Calc(select=[x, y], where=[=(z, 2)])
+ *          +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z])
+ * }</pre>
+ *
+ * <p>The optimized physical plan:
+ *
+ * <pre>{@code
+ * Calc(select=[a, b, c, x, y, CAST(2 AS BIGINT) AS z])
+ * +- HashJoin(joinType=[InnerJoin], where=[=(x, a)], select=[a, b, c, x, y], 
+ *    :- Exchange(distribution=[hash[a]])
+ *    :  +- RuntimeFilter(select=[a])
+ *    :     :- Exchange(distribution=[broadcast])
+ *    :     :  +- GlobalRuntimeFilterBuilder
+ *    :     :     +- Exchange(distribution=[single])
+ *    :     :        +- LocalRuntimeFilterBuilder(select=[x])
+ *    :     :           +- Calc(select=[x, y], where=[=(z, 2)])
+ *    :     :              +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z])
+ *    :     +- TableSourceScan(table=[[fact]], fields=[a, b, c])
+ *    +- Exchange(distribution=[hash[x]])
+ *       +- Calc(select=[x, y], where=[=(z, 2)])
+ *          +- TableSourceScan(table=[[dim, filter=[]]], fields=[x, y, z])
+ *
+ * }</pre>
+ */
+public class FlinkRuntimeFilterProgram implements 
FlinkOptimizeProgram<BatchOptimizeContext> {
+    /**
+     * Rewrites the plan and tries to inject a runtime filter for every {@link Join} it contains.
+     *
+     * <p>If {@code isRuntimeFilterEnabled(root)} returns false, {@code root} is returned
+     * unchanged. Otherwise the min probe data size must be larger than the max build data size
+     * (enforced with {@code checkState}). Then the plan is traversed, and each join's inputs are
+     * rewritten before {@code tryInjectRuntimeFilter} is applied to the join itself.
+     *
+     * @param root the root of the physical plan
+     * @param context the batch optimize context (not read by this method)
+     * @return the rewritten plan
+     */
+    @Override
+    public RelNode optimize(RelNode root, BatchOptimizeContext context) {
+        if (!isRuntimeFilterEnabled(root)) {
+            return root;
+        }
+        // To avoid that one side can be used both as a build side and as a 
probe side
+        checkState(
+                getMinProbeDataSize(root) > getMaxBuildDataSize(root),
+                "The min probe data size should be larger than the max build 
data size.");
+        // Rebuilds every node from its rewritten inputs; only Join nodes are candidates for
+        // runtime filter injection.
+        DefaultRelShuttle shuttle =
+                new DefaultRelShuttle() {
+                    @Override
+                    public RelNode visit(RelNode rel) {
+                        if (!(rel instanceof Join)) {
+                            // Non-join node: rewrite each input recursively, then copy the node
+                            // with the rewritten inputs and its original trait set.
+                            List<RelNode> newInputs = new ArrayList<>();
+                            for (RelNode input : rel.getInputs()) {
+                                RelNode newInput = input.accept(this);
+                                newInputs.add(newInput);
+                            }
+                            return rel.copy(rel.getTraitSet(), newInputs);
+                        }
+                        // Rewrite both join inputs first, so joins nested below this one are
+                        // processed before injection is attempted on this join.
+                        Join join = (Join) rel;
+                        RelNode newLeft = join.getLeft().accept(this);
+                        RelNode newRight = join.getRight().accept(this);
+                        return tryInjectRuntimeFilter(
+                                join.copy(join.getTraitSet(), 
Arrays.asList(newLeft, newRight)));
+                    }
+                };
+        return shuttle.visit(root);
+    }
+    /**
+     * Judge whether the join is suitable, and try to inject runtime filter 
for it.
+     *
+     * @param join the join node
+     * @return the new join node with runtime filter.
+     */
+    private static Join tryInjectRuntimeFilter(Join join) {
+        // check supported join type
+        if (join.getJoinType() != JoinRelType.INNER
+                && join.getJoinType() != JoinRelType.SEMI
+                && join.getJoinType() != JoinRelType.LEFT
+                && join.getJoinType() != JoinRelType.RIGHT) {
+            return join;
+        }
+        // check supported join implementation
+        if (!(join instanceof BatchPhysicalHashJoin)
+                && !(join instanceof BatchPhysicalSortMergeJoin)) {
+            return join;
+        }
+        boolean leftIsBuild;
+        if (canBeProbeSide(join.getLeft())) {
+            leftIsBuild = false;
+        } else if (canBeProbeSide(join.getRight())) {
+            leftIsBuild = true;
+        } else {
+            return join;
+        }
+        // check left join + left build
+        if (join.getJoinType() == JoinRelType.LEFT && !leftIsBuild) {
+            return join;
+        }
+        // check right join + right build
+        if (join.getJoinType() == JoinRelType.RIGHT && leftIsBuild) {
+            return join;
+        }
+        JoinInfo joinInfo = join.analyzeCondition();
+        RelNode buildSide;
+        RelNode probeSide;
+        ImmutableIntList buildIndices;
+        ImmutableIntList probeIndices;
+        if (leftIsBuild) {
+            buildSide = join.getLeft();
+            probeSide = join.getRight();
+            buildIndices = joinInfo.leftKeys;
+            probeIndices = joinInfo.rightKeys;
+        } else {
+            buildSide = join.getRight();
+            probeSide = join.getLeft();
+            buildIndices = joinInfo.rightKeys;
+            probeIndices = joinInfo.leftKeys;
+        }
+        Optional<BuildSideInfo> suitableBuildOpt =
+                findSuitableBuildSide(
+                        buildSide,
+                        buildIndices,
+                        (build, indices) ->
+                                isSuitableDataSize(build, probeSide, indices, 
+        if (suitableBuildOpt.isPresent()) {
+            BuildSideInfo suitableBuildInfo = suitableBuildOpt.get();
+            RelNode newProbe =
+                    tryPushDownProbeAndInjectRuntimeFilter(
+                            probeSide, probeIndices, suitableBuildInfo);
+            if (leftIsBuild) {
+                return join.copy(join.getTraitSet(), Arrays.asList(buildSide, 
+            } else {
+                return join.copy(join.getTraitSet(), Arrays.asList(newProbe, 
+            }
+        }
+        return join;
+    }
+    /**
+     * Inject runtime filter and return the new probe side (without exchange).
+     *
+     * @param buildSide the build side
+     * @param probeSide the probe side
+     * @param buildIndices the build projection
+     * @param probeIndices the probe projection
+     * @return the new probe side
+     */
+    private static RelNode createNewProbeWithRuntimeFilter(
+            RelNode buildSide,
+            RelNode probeSide,
+            ImmutableIntList buildIndices,
+            ImmutableIntList probeIndices) {
+        Optional<Double> buildRowCountOpt = getEstimatedRowCount(buildSide);
+        checkState(buildRowCountOpt.isPresent());
+        int buildRowCount = buildRowCountOpt.get().intValue();
+        int maxRowCount =
+                (int)
+                        Math.ceil(
+                                getMaxBuildDataSize(buildSide)
+                                        / 
+        double filterRatio = computeFilterRatio(buildSide, probeSide, 
buildIndices, probeIndices);
+        String[] buildFiledNames =
+                        .map(buildSide.getRowType().getFieldNames()::get)
+                        .toArray(String[]::new);
+        RelNode localBuilder =
+                new BatchPhysicalLocalRuntimeFilterBuilder(
+                        buildSide.getCluster(),
+                        buildSide.getTraitSet(),
+                        buildSide,
+                        buildIndices.toIntArray(),
+                        buildFiledNames,
+                        buildRowCount,
+                        maxRowCount);
+        RelNode globalBuilder =
+                new BatchPhysicalGlobalRuntimeFilterBuilder(
+                        localBuilder.getCluster(),
+                        localBuilder.getTraitSet(),
+                        createExchange(localBuilder, 
+                        buildFiledNames,
+                        buildRowCount,
+                        maxRowCount);
+        RelNode runtimeFilter =
+                new BatchPhysicalRuntimeFilter(
+                        probeSide.getCluster(),
+                        probeSide.getTraitSet(),
+                        createExchange(globalBuilder, 
+                        probeSide,
+                        probeIndices.toIntArray(),
+                        filterRatio);
+        return runtimeFilter;
+    }
+    /**
+     * Find a suitable build side. In order not to affect MultiInput, when the 
original build side
+     * of runtime filter is not an {@link Exchange}, we need to push down the 
builder, until we find
+     * an exchange and inject the builder there.
+     *
+     * @param rel the original build side
+     * @param buildIndices build indices
+     * @param buildSideChecker check whether current build side is suitable
+     * @return An optional info of the suitable build side.It will be empty if 
we cannot find the
+     *     suitable build side.
+     */
+    private static Optional<BuildSideInfo> findSuitableBuildSide(
+            RelNode rel,
+            ImmutableIntList buildIndices,
+            BiFunction<RelNode, ImmutableIntList, Boolean> buildSideChecker) {
+        if (rel instanceof Exchange) {
+            // found the desired exchange, inject builder here
+            Exchange exchange = (Exchange) rel;
+            if (!(exchange.getInput() instanceof BatchPhysicalRuntimeFilter)
+                    && buildSideChecker.apply(exchange.getInput(), 
buildIndices)) {
+                return Optional.of(new BuildSideInfo(exchange.getInput(), 
+            }
+        } else if (rel instanceof BatchPhysicalRuntimeFilter) {
+            // runtime filter should not as build side
+            return Optional.empty();
+        } else if (rel instanceof Calc) {
+            // try to push the builder to input of projection
+            Calc calc = ((Calc) rel);
+            RexProgram program = calc.getProgram();
+            List<RexNode> projects =
+                    program.getProjectList().stream()
+                            .map(program::expandLocalRef)
+                            .collect(Collectors.toList());
+            ImmutableIntList inputIndices = getInputIndices(projects, 
+            if (inputIndices.isEmpty()) {
+                return Optional.empty();
+            }
+            return findSuitableBuildSide(calc.getInput(), inputIndices, 
+        } else if (rel instanceof Join) {
+            // try to push the builder to one input of join
+            Join join = (Join) rel;
+            if (!(join.getLeft() instanceof Exchange) && !(join.getRight() 
instanceof Exchange)) {
+                return Optional.empty();
+            }
+            Tuple2<ImmutableIntList, ImmutableIntList> tuple2 = 
getInputIndices(join, buildIndices);
+            ImmutableIntList leftIndices = tuple2.f0;
+            ImmutableIntList rightIndices = tuple2.f1;
+            if (leftIndices.isEmpty() && rightIndices.isEmpty()) {
+                return Optional.empty();
+            }
+            boolean firstCheckLeft = !leftIndices.isEmpty() && join.getLeft() 
instanceof Exchange;
+            Optional<BuildSideInfo> buildSideInfoOpt = Optional.empty();
+            if (firstCheckLeft) {
+                buildSideInfoOpt =
+                        findSuitableBuildSide(join.getLeft(), leftIndices, 
+                if (!buildSideInfoOpt.isPresent() && !rightIndices.isEmpty()) {
+                    buildSideInfoOpt =
+                            findSuitableBuildSide(join.getRight(), 
rightIndices, buildSideChecker);
+                }
+                return buildSideInfoOpt;
+            } else {
+                if (!rightIndices.isEmpty()) {
+                    buildSideInfoOpt =
+                            findSuitableBuildSide(join.getRight(), 
rightIndices, buildSideChecker);
+                    if (!buildSideInfoOpt.isPresent() && 
!leftIndices.isEmpty()) {
+                        buildSideInfoOpt =
+                                findSuitableBuildSide(
+                                        join.getLeft(), leftIndices, 
+                    }
+                }
+            }
+            return buildSideInfoOpt;
+        } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
+            // try to push the builder to input of agg, iff the indices are 
all in grouping keys.
+            BatchPhysicalGroupAggregateBase agg = 
(BatchPhysicalGroupAggregateBase) rel;
+            int[] grouping = agg.grouping();
+            for (int k : buildIndices) {
+                if (k >= grouping.length) {
+                    return Optional.empty();
+                }
+            }
+            return findSuitableBuildSide(
+                    agg.getInput(),
+                    ImmutableIntList.copyOf(
+                                    .map(index -> agg.grouping()[index])
+                                    .collect(Collectors.toList())),
+                    buildSideChecker);
+        } else {
+            // the above cases can cover all cases of TPC-DS test
+            // we may find more cases later
+        }
+        return Optional.empty();
+    }
+    /**
+     * Try to push down the probe side of runtime filter, and inject the 
runtime filter.
+     *
+     * @param rel the original probe side
+     * @param probeIndices the probe indices
+     * @param buildSideInfo the build side info
+     * @return the new probe side wit runtime filter
+     */
+    private static RelNode tryPushDownProbeAndInjectRuntimeFilter(
+            RelNode rel, ImmutableIntList probeIndices, BuildSideInfo 
buildSideInfo) {
+        if (rel instanceof BatchPhysicalRuntimeFilter) {
+            // do nothing, return current probe side directly. Because we 
don't inject more than
+            // once runtime filter at the same place
+            return rel;
+        } else if (rel instanceof Exchange) {
+            // try to push the probe side to the input of exchange
+            Exchange exchange = (Exchange) rel;
+            return exchange.copy(
+                    exchange.getTraitSet(),
+                    Collections.singletonList(
+                            tryPushDownProbeAndInjectRuntimeFilter(
+                                    exchange.getInput(), probeIndices, 
+        } else if (rel instanceof Calc) {
+            // try to push the probe side to the input of projection
+            Calc calc = ((Calc) rel);
+            RexProgram program = calc.getProgram();
+            List<RexNode> projects =
+                    program.getProjectList().stream()
+                            .map(program::expandLocalRef)
+                            .collect(Collectors.toList());
+            ImmutableIntList inputIndices = getInputIndices(projects, 
+            if (!inputIndices.isEmpty()) {
+                return calc.copy(
+                        calc.getTraitSet(),
+                        Collections.singletonList(
+                                tryPushDownProbeAndInjectRuntimeFilter(
+                                        calc.getInput(), inputIndices, 
+            }
+        } else if (rel instanceof Join) {
+            // try to push the probe side to the all inputs of join
+            Join join = (Join) rel;
+            Tuple2<ImmutableIntList, ImmutableIntList> tuple2 = 
getInputIndices(join, probeIndices);
+            ImmutableIntList leftIndices = tuple2.f0;
+            ImmutableIntList rightIndices = tuple2.f1;
+            if (!leftIndices.isEmpty() || !rightIndices.isEmpty()) {
+                RelNode leftSide = join.getLeft();
+                RelNode rightSide = join.getRight();
+                if (!leftIndices.isEmpty()) {
+                    leftSide =
+                            tryPushDownProbeAndInjectRuntimeFilter(
+                                    leftSide, leftIndices, buildSideInfo);
+                }
+                if (!rightIndices.isEmpty()) {
+                    rightSide =
+                            tryPushDownProbeAndInjectRuntimeFilter(
+                                    rightSide, rightIndices, buildSideInfo);
+                }
+                return join.copy(join.getTraitSet(), Arrays.asList(leftSide, 
+            }
+        } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
+            // try to push the probe side to input of agg, iff the indices are 
all in grouping keys.
+            BatchPhysicalGroupAggregateBase agg = 
(BatchPhysicalGroupAggregateBase) rel;
+            int[] grouping = agg.grouping();
+            if ( -> (index < 
grouping.length))) {

Review Comment:
   It seems that this check's logic is not correct: `index` is not necessarily
   smaller than `grouping.length`.

@@ -0,0 +1,417 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.optimize.program;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.catalog.ObjectPath;
+import org.apache.flink.table.catalog.stats.CatalogColumnStatistics;
+import org.apache.flink.table.catalog.stats.CatalogColumnStatisticsDataLong;
+import org.apache.flink.table.catalog.stats.CatalogTableStatistics;
+import org.apache.flink.table.planner.factories.TestValuesCatalog;
+import org.apache.flink.table.planner.utils.BatchTableTestUtil;
+import org.apache.flink.table.planner.utils.TableTestBase;
+import org.junit.Before;
+import org.junit.Test;
+import java.util.Collections;
+/** Test for {@link FlinkRuntimeFilterProgram}. */
+public class FlinkRuntimeFilterProgramTest extends TableTestBase {

Review Comment:
   Please also add some ITCases to verify the end-to-end process.

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail:

For queries about this service, please contact Infrastructure at:

Reply via email to