godfreyhe commented on code in PR #21530: URL: https://github.com/apache/flink/pull/21530#discussion_r1057178039
########## docs/layouts/shortcodes/generated/optimizer_config_configuration.html: ########## @@ -89,5 +89,17 @@ <td>Boolean</td> <td>When it is true, the optimizer will collect and use the statistics from source connectors if the source extends from SupportsStatisticReport and the statistics from catalog is UNKNOWN.Default value is true.</td> </tr> + <tr> + <td><h5>table.optimizer.busy-join-reorder</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td> + <td style="word-wrap: break-word;">false</td> + <td>Boolean</td> + <td>Enables busy join reorder in optimizer. Default is disabled.</td> + </tr> + <tr> + <td><h5>table.optimizer.busy-join-reorder-dp-threshold</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td> Review Comment: Introducing a public option requires us to first post a FLIP or start a DISCUSS thread. Also, we can remove the keyword `dp` from the option name, since both join reorder programs use a dynamic programming algorithm, e.g. `table.optimizer.busy-join-reorder-threshold`. ########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkBusyJoinReorderRule.java: ########## @@ -0,0 +1,688 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.plan.cost.FlinkCost; +import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.LoptMultiJoin; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * Flink busy join reorder rule, which will convert {@link MultiJoin} to a busy join tree. 
+ * + * <p>In this busy join reorder strategy, we will first try to reorder all the inner join type + * inputs in the multiJoin, and then add all outer join type inputs on the top. + * + * <p>First, reordering all the inner join type inputs in the multiJoin. We adopt the concept of + * level in dynamic programming, and the latter layer will use the results stored in the previous + * layers. First, we put all inputs (each input in {@link MultiJoin}) into level 0, then we build + * all two-inputs join at level 1 based on the {@link FlinkCost} of level 0, then we will build + * three-inputs join based on the previous two levels, then four-inputs joins ... etc, util we + * reorder all the inner join type inputs in the multiJoin. When building m-inputs join, we only + * keep the best plan (have the lowest {@link FlinkCost}) for the same set of m inputs. E.g., for + * three-inputs join, we keep only the best plan for inputs {A, B, C} among plans (A J B) J C, (A J + * C) J B, (B J C) J A. + * + * <p>Second, we will add all outer join type inputs in the MultiJoin on the top. + */ +public class FlinkBusyJoinReorderRule extends RelRule<FlinkBusyJoinReorderRule.Config> + implements TransformationRule { + + public static final LoptOptimizeJoinRule MULTI_JOIN_OPTIMIZE = + LoptOptimizeJoinRule.Config.DEFAULT.toRule(); + + /** Creates an SparkJoinReorderRule. 
*/ protected FlinkBusyJoinReorderRule(Config config) { + super(config); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule(RelBuilderFactory relBuilderFactory) { + this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class)); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule( + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + RelFactories.FilterFactory filterFactory) { + this(RelBuilder.proto(joinFactory, projectFactory, filterFactory)); + } + + @Override + public void onMatch(RelOptRuleCall call) { Review Comment: It would be better to introduce a `FlinkJoinReorderRule` that chooses the specific join reorder algorithm based on `dpThreshold`. ########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkBusyJoinReorderRule.java: ########## @@ -0,0 +1,688 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.plan.cost.FlinkCost; +import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.LoptMultiJoin; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * Flink busy join reorder rule, which will convert {@link MultiJoin} to a busy join tree. 
+ * + * <p>In this busy join reorder strategy, we will first try to reorder all the inner join type + * inputs in the multiJoin, and then add all outer join type inputs on the top. + * + * <p>First, reordering all the inner join type inputs in the multiJoin. We adopt the concept of + * level in dynamic programming, and the latter layer will use the results stored in the previous + * layers. First, we put all inputs (each input in {@link MultiJoin}) into level 0, then we build + * all two-inputs join at level 1 based on the {@link FlinkCost} of level 0, then we will build + * three-inputs join based on the previous two levels, then four-inputs joins ... etc, util we + * reorder all the inner join type inputs in the multiJoin. When building m-inputs join, we only + * keep the best plan (have the lowest {@link FlinkCost}) for the same set of m inputs. E.g., for + * three-inputs join, we keep only the best plan for inputs {A, B, C} among plans (A J B) J C, (A J + * C) J B, (B J C) J A. + * + * <p>Second, we will add all outer join type inputs in the MultiJoin on the top. + */ +public class FlinkBusyJoinReorderRule extends RelRule<FlinkBusyJoinReorderRule.Config> + implements TransformationRule { + + public static final LoptOptimizeJoinRule MULTI_JOIN_OPTIMIZE = + LoptOptimizeJoinRule.Config.DEFAULT.toRule(); + + /** Creates an SparkJoinReorderRule. 
*/ + protected FlinkBusyJoinReorderRule(Config config) { + super(config); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule(RelBuilderFactory relBuilderFactory) { + this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class)); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule( + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + RelFactories.FilterFactory filterFactory) { + this(RelBuilder.proto(joinFactory, projectFactory, filterFactory)); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final MultiJoin multiJoinRel = call.rel(0); + Boolean enableBusyJoin = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER); + Integer dpThreshold = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER_DP_THRESHOLD); + final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel); + int numJoinFactors = multiJoin.getNumJoinFactors(); + + if (enableBusyJoin && numJoinFactors <= dpThreshold) { + RelNode bestOrder = findBestOrder(call.getMetadataQuery(), relBuilder, multiJoin); + call.transformTo(bestOrder); + } else { + MULTI_JOIN_OPTIMIZE.onMatch(call); + } + } + + private static RelNode findBestOrder( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + // In our busy join reorder strategy, we will first try to reorder all the inner join type + // inputs in the multiJoin, and then add all outer join type inputs on the top. + // First, reorder all the inner join type inputs in the multiJoin. + List<Map<Set<Integer>, JoinPlan>> foundPlans = reOrderInnerJoin(mq, relBuilder, multiJoin); + + JoinPlan finalPlan; + // Second, add all outer join type inputs in the multiJoin on the top. 
+ if (canOuterJoin(multiJoin)) { + finalPlan = + addToTopForOuterJoin( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + if (foundPlans.size() != multiJoin.getNumJoinFactors()) { + finalPlan = + addToTop( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + assert foundPlans.get(foundPlans.size() - 1).size() == 1; + finalPlan = new ArrayList<>(foundPlans.get(foundPlans.size() - 1).values()).get(0); + } + } + + final List<String> fieldNames = multiJoin.getMultiJoinRel().getRowType().getFieldNames(); + return creatToProject(relBuilder, multiJoin, finalPlan, fieldNames); + } + + private static List<Map<Set<Integer>, JoinPlan>> reOrderInnerJoin( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>(); + + // First, we put each input in MultiJoin into level 0. + Map<Set<Integer>, JoinPlan> joinPlanMap = new LinkedHashMap<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!multiJoin.isNullGenerating(i)) { + HashSet<Integer> set1 = new HashSet<>(); + LinkedHashSet<Integer> set2 = new LinkedHashSet<>(); + set1.add(i); + set2.add(i); + RelNode joinFactor = multiJoin.getJoinFactor(i); + RelOptCost cost = mq.getCumulativeCost(joinFactor); + joinPlanMap.put( + set1, + new JoinPlan( + set2, + joinFactor, + new FlinkCost( + cost.getRows(), cost.getCpu(), cost.getIo(), 0.0, 0.0))); + } + } + foundPlans.add(joinPlanMap); + + // Build plans for next levels until the last level has only one plan. 
This plan contains + // all inputs that can be joined, so there's no need to continue + while (foundPlans.size() < multiJoin.getNumJoinFactors()) { + Map<Set<Integer>, JoinPlan> levelPlan = + searchLevel(mq, relBuilder, new ArrayList<>(foundPlans), multiJoin, false); + if (levelPlan.size() == 0) { + break; + } + foundPlans.add(levelPlan); + } + + return foundPlans; + } + + private static boolean canOuterJoin(LoptMultiJoin multiJoin) { + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (multiJoin.getOuterJoinCond(i) != null + && RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)).size() != 0) { + return true; + } + } + return false; + } + + private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlan) { + JoinPlan bestPlan = null; + for (Map.Entry<Set<Integer>, JoinPlan> entry : levelPlan.entrySet()) { + if (bestPlan == null || entry.getValue().cost.isLt(bestPlan.cost)) { + bestPlan = entry.getValue(); + } + } + + return bestPlan; + } + + private static JoinPlan addToTopForOuterJoin( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + RelNode rightNode = multiJoin.getJoinFactor(index); + + // make new join condition + Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConds = + getConditionsAndJoinType( + bestPlan.itemIds, Collections.singleton(index), multiJoin, true); + + if (!joinConds.isPresent()) { + // join type is always left. 
+ leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, rexBuilder.makeLiteral(true)) + .build(); + } else { + Set<RexCall> conditions = joinConds.get().f0; + List<RexNode> rexCalls = new ArrayList<>(conditions); + Set<RexCall> newCondition = + convertToNewCondition( + new ArrayList<>(set), + Collections.singletonList(index), + rexCalls, + multiJoin); + // all given left join. + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, newCondition) + .build(); + } + set.add(index); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static JoinPlan addToTop( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + set.add(index); + RelNode rightNode = multiJoin.getJoinFactor(index); + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join( + multiJoin.getMultiJoinRel().getJoinTypes().get(index), + rexBuilder.makeLiteral(true)) + .build(); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static RelNode creatToProject( + RelBuilder relBuilder, + LoptMultiJoin multiJoin, + JoinPlan finalPlan, + List<String> fieldNames) { + List<RexNode> newProjExprs = new ArrayList<>(); + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + + List<Integer> newJoinOrder = new ArrayList<>(finalPlan.itemIds); + int nJoinFactors = multiJoin.getNumJoinFactors(); + List<RelDataTypeField> fields = multiJoin.getMultiJoinFields(); + + // create a mapping from each factor to 
its field offset in the join + // ordering + final Map<Integer, Integer> factorToOffsetMap = new HashMap<>(); + for (int pos = 0, fieldStart = 0; pos < nJoinFactors; pos++) { + factorToOffsetMap.put(newJoinOrder.get(pos), fieldStart); + fieldStart += multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(pos)); + } + + for (int currFactor = 0; currFactor < nJoinFactors; currFactor++) { + // if the factor is the right factor in a removable self-join, + // then where possible, remap references to the right factor to + // the corresponding reference in the left factor + Integer leftFactor = null; + if (multiJoin.isRightFactorInRemovableSelfJoin(currFactor)) { + leftFactor = multiJoin.getOtherSelfJoinFactor(currFactor); + } + for (int fieldPos = 0; + fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor); + fieldPos++) { + int newOffset = + requireNonNull( + factorToOffsetMap.get(currFactor), + () -> "factorToOffsetMap.get(currFactor)") + + fieldPos; + if (leftFactor != null) { + Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos); + if (leftOffset != null) { + newOffset = + requireNonNull( + factorToOffsetMap.get(leftFactor), + "factorToOffsetMap.get(leftFactor)") + + leftOffset; + } + } + newProjExprs.add( + rexBuilder.makeInputRef( + fields.get(newProjExprs.size()).getType(), newOffset)); + } + } + + relBuilder.push(finalPlan.relNode); + relBuilder.project(newProjExprs, fieldNames); + + // Place the post-join filter (if it exists) on top of the final + // projection. 
+ RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter(); + if (postJoinFilter != null) { + relBuilder.filter(postJoinFilter); + } + return relBuilder.build(); + } + + private static Map<Set<Integer>, JoinPlan> searchLevel( + RelMetadataQuery mq, + RelBuilder relBuilder, + List<Map<Set<Integer>, JoinPlan>> existingLevels, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + Map<Set<Integer>, List<JoinPlan>> printNextLevel = new LinkedHashMap<>(); + Map<Set<Integer>, JoinPlan> nextLevel = new LinkedHashMap<>(); + int k = 0; + int lev = existingLevels.size() - 1; + while (k <= lev - k) { + ArrayList<JoinPlan> oneSideCandidates = new ArrayList<>(existingLevels.get(k).values()); + int oneSideSize = oneSideCandidates.size(); + for (int i = 0; i < oneSideSize; i++) { + JoinPlan oneSidePlan = oneSideCandidates.get(i); + ArrayList<JoinPlan> otherSideCandidates; + if (k == lev - k) { + otherSideCandidates = new ArrayList<>(oneSideCandidates); + if (i > 0) { + otherSideCandidates.subList(0, i).clear(); + } + } else { + otherSideCandidates = new ArrayList<>(existingLevels.get(lev - k).values()); + } + for (JoinPlan otherSidePlan : otherSideCandidates) { + Optional<JoinPlan> newJoinPlan = + buildJoin( + mq, + relBuilder, + oneSidePlan, + otherSidePlan, + multiJoin, + isOuterJoin); + if (newJoinPlan.isPresent()) { + JoinPlan existingPlan = nextLevel.get(newJoinPlan.get().itemIds); + // check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. 
+ if (existingPlan == null || newJoinPlan.get().betterThan(existingPlan)) { + nextLevel.put(newJoinPlan.get().itemIds, newJoinPlan.get()); + } + + if (printNextLevel.get(newJoinPlan.get().itemIds) == null) { + printNextLevel.put( + newJoinPlan.get().itemIds, + Collections.singletonList(newJoinPlan.get())); + } else { + List<JoinPlan> joinPlans = + new ArrayList<>(printNextLevel.get(newJoinPlan.get().itemIds)); + joinPlans.add(newJoinPlan.get()); + printNextLevel.put(newJoinPlan.get().itemIds, joinPlans); + } + } + } + } + k += 1; + } + + // print + for (Map.Entry<Set<Integer>, List<JoinPlan>> entry : printNextLevel.entrySet()) { + System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++"); + System.out.printf("item sets: %s%n", entry.getKey()); + for (JoinPlan joinPlan : entry.getValue()) { + System.out.println("--------------------------------------------"); + System.out.printf("costs: %s%n", joinPlan.cost.getRows()); + System.out.println( + FlinkRelOptUtil.toString( + joinPlan.relNode, + SqlExplainLevel.ALL_ATTRIBUTES, + false, + false, + false, + false, + false)); + System.out.println("--------------------------------------------"); + } + System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++"); + } + return nextLevel; + } + + private static Optional<JoinPlan> buildJoin( + RelMetadataQuery mq, + RelBuilder relBuilder, + JoinPlan oneSidePlan, + JoinPlan otherSidePlan, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + // intersect, should not join two overlapping item sets. 
+ Set<Integer> resSet = new HashSet<>(oneSidePlan.itemIds); + resSet.retainAll(otherSidePlan.itemIds); + if (!resSet.isEmpty()) { + return Optional.empty(); + } + + Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConds = + getConditionsAndJoinType( + oneSidePlan.itemIds, otherSidePlan.itemIds, multiJoin, isOuterJoin); + if (!joinConds.isPresent()) { + return Optional.empty(); + } + + Set<RexCall> conditions = joinConds.get().f0; + JoinRelType joinType = joinConds.get().f1; + + LinkedHashSet<Integer> newItemIds = new LinkedHashSet<>(); + JoinPlan leftPlan; + JoinPlan rightPlan; + // put the deeper side on the left, tend to build a left-deep tree. + if (oneSidePlan.itemIds.size() >= otherSidePlan.itemIds.size()) { + leftPlan = oneSidePlan; + rightPlan = otherSidePlan; + } else { + leftPlan = otherSidePlan; + rightPlan = oneSidePlan; + if (isOuterJoin) { + joinType = (joinType == JoinRelType.LEFT) ? JoinRelType.RIGHT : JoinRelType.LEFT; + } + } + newItemIds.addAll(leftPlan.itemIds); + newItemIds.addAll(rightPlan.itemIds); + + List<RexNode> rexCalls = new ArrayList<>(conditions); + Set<RexCall> newCondition = + convertToNewCondition( + new ArrayList<>(leftPlan.itemIds), + new ArrayList<>(rightPlan.itemIds), + rexCalls, + multiJoin); + + Join newJoin = + (Join) + relBuilder + .push(leftPlan.relNode) + .push(rightPlan.relNode) + .join(joinType, newCondition) + .build(); + + RelOptCost cost = mq.getCumulativeCost(newJoin); + return Optional.of( + new JoinPlan( + newItemIds, + newJoin, + new FlinkCost(cost.getRows(), cost.getCpu(), cost.getIo(), 0.0, 0.0))); + } + + private static Set<RexCall> convertToNewCondition( + List<Integer> leftItemIds, + List<Integer> rightItemIds, + List<RexNode> rexNodes, + LoptMultiJoin multiJoin) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + Set<RexCall> newCondition = new HashSet<>(); + for (RexNode cond : rexNodes) { + RexCall rexCond = (RexCall) cond; + List<RexNode> resultRexNode = new 
ArrayList<>(); + for (RexNode rexNode : rexCond.getOperands()) { + rexNode = + rexNode.accept( + new RexInputConverterForBusyJoin( + rexBuilder, multiJoin, leftItemIds, rightItemIds)); + resultRexNode.add(rexNode); + } + RexNode resultRex = rexBuilder.makeCall(rexCond.op, resultRexNode); + newCondition.add((RexCall) resultRex); + } + + return newCondition; + } + + private static Optional<Tuple2<Set<RexCall>, JoinRelType>> getConditionsAndJoinType( + Set<Integer> oneItemIds, + Set<Integer> otherItemIds, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + if (oneItemIds.size() + otherItemIds.size() < 2) { + return Optional.empty(); + } + JoinRelType joinType = JoinRelType.INNER; + if (multiJoin.getMultiJoinRel().isFullOuterJoin()) { + assert multiJoin.getNumJoinFactors() == 2; + joinType = JoinRelType.FULL; + } + + Set<RexCall> resultRexCall = new HashSet<>(); + List<RexNode> joinConditions = new ArrayList<>(); + if (isOuterJoin) { + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + joinConditions.addAll(RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i))); + } + } else { + joinConditions = multiJoin.getJoinFilters(); + } + + for (RexNode joinCond : joinConditions) { + if (joinCond instanceof RexCall) { + RexCall callCondition = (RexCall) joinCond; + ImmutableBitSet factorsRefByJoinFilter = + multiJoin.getFactorsRefByJoinFilter(callCondition); + int oneItemNumbers = 0; + int otherItemNumbers = 0; + for (int oneItemId : oneItemIds) { + if (factorsRefByJoinFilter.get(oneItemId)) { + oneItemNumbers++; + if (isOuterJoin && multiJoin.isNullGenerating(oneItemId)) { + joinType = JoinRelType.RIGHT; + } + } + } + for (int otherItemId : otherItemIds) { + if (factorsRefByJoinFilter.get(otherItemId)) { + otherItemNumbers++; + if (isOuterJoin && multiJoin.isNullGenerating(otherItemId)) { + joinType = JoinRelType.LEFT; + } + } + } + + if (oneItemNumbers > 0 + && otherItemNumbers > 0 + && oneItemNumbers + otherItemNumbers + == factorsRefByJoinFilter.asSet().size()) { 
+ resultRexCall.add(callCondition); + } + } else { + return Optional.empty(); + } + } + + if (resultRexCall.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(Tuple2.of(resultRexCall, joinType)); + } + } + + // ~ Inner Classes ---------------------------------------------------------- + private static class JoinPlan { + final LinkedHashSet<Integer> itemIds; + final RelNode relNode; + final FlinkCost cost; Review Comment: The cost can be recomputed via `mq.getCumulativeCost`. Is this field only used as a cache to speed up the computation? ########## flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java: ########## @@ -135,6 +135,24 @@ public class OptimizerConfigOptions { .defaultValue(false) .withDescription("Enables join reorder in optimizer. Default is disabled."); + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING) + public static final ConfigOption<Boolean> TABLE_OPTIMIZER_BUSY_JOIN_REORDER = + key("table.optimizer.busy-join-reorder") + .booleanType() + .defaultValue(false) + .withDescription( + "Enables busy join reorder in optimizer. Default is disabled."); + + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING) + public static final ConfigOption<Integer> TABLE_OPTIMIZER_BUSY_JOIN_REORDER_DP_THRESHOLD = + key("table.optimizer.busy-join-reorder-dp-threshold") + .intType() + .defaultValue(12) Review Comment: Can you explain why we chose `12` as the threshold value?
########## docs/layouts/shortcodes/generated/optimizer_config_configuration.html: ########## @@ -89,5 +89,17 @@ <td>Boolean</td> <td>When it is true, the optimizer will collect and use the statistics from source connectors if the source extends from SupportsStatisticReport and the statistics from catalog is UNKNOWN.Default value is true.</td> </tr> + <tr> + <td><h5>table.optimizer.busy-join-reorder</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td> Review Comment: This option can be removed, since busy join reorder can be disabled by setting `table.optimizer.busy-join-reorder-dp-threshold` to a value <= 0. ########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkBusyJoinReorderRule.java: ########## @@ -0,0 +1,688 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.plan.cost.FlinkCost; +import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.LoptMultiJoin; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * Flink busy join reorder rule, which will convert {@link MultiJoin} to a busy join tree. 
+ * + * <p>In this busy join reorder strategy, we will first try to reorder all the inner join type + * inputs in the multiJoin, and then add all outer join type inputs on the top. + * + * <p>First, reordering all the inner join type inputs in the multiJoin. We adopt the concept of + * level in dynamic programming, and the latter layer will use the results stored in the previous + * layers. First, we put all inputs (each input in {@link MultiJoin}) into level 0, then we build + * all two-inputs join at level 1 based on the {@link FlinkCost} of level 0, then we will build + * three-inputs join based on the previous two levels, then four-inputs joins ... etc, util we + * reorder all the inner join type inputs in the multiJoin. When building m-inputs join, we only + * keep the best plan (have the lowest {@link FlinkCost}) for the same set of m inputs. E.g., for + * three-inputs join, we keep only the best plan for inputs {A, B, C} among plans (A J B) J C, (A J + * C) J B, (B J C) J A. + * + * <p>Second, we will add all outer join type inputs in the MultiJoin on the top. + */ +public class FlinkBusyJoinReorderRule extends RelRule<FlinkBusyJoinReorderRule.Config> + implements TransformationRule { + + public static final LoptOptimizeJoinRule MULTI_JOIN_OPTIMIZE = + LoptOptimizeJoinRule.Config.DEFAULT.toRule(); + + /** Creates an SparkJoinReorderRule. 
*/ + protected FlinkBusyJoinReorderRule(Config config) { + super(config); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule(RelBuilderFactory relBuilderFactory) { + this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class)); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule( + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + RelFactories.FilterFactory filterFactory) { + this(RelBuilder.proto(joinFactory, projectFactory, filterFactory)); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final MultiJoin multiJoinRel = call.rel(0); + Boolean enableBusyJoin = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER); + Integer dpThreshold = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER_DP_THRESHOLD); + final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel); + int numJoinFactors = multiJoin.getNumJoinFactors(); + + if (enableBusyJoin && numJoinFactors <= dpThreshold) { + RelNode bestOrder = findBestOrder(call.getMetadataQuery(), relBuilder, multiJoin); + call.transformTo(bestOrder); + } else { + MULTI_JOIN_OPTIMIZE.onMatch(call); + } + } + + private static RelNode findBestOrder( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + // In our busy join reorder strategy, we will first try to reorder all the inner join type + // inputs in the multiJoin, and then add all outer join type inputs on the top. + // First, reorder all the inner join type inputs in the multiJoin. + List<Map<Set<Integer>, JoinPlan>> foundPlans = reOrderInnerJoin(mq, relBuilder, multiJoin); + + JoinPlan finalPlan; + // Second, add all outer join type inputs in the multiJoin on the top. 
+ if (canOuterJoin(multiJoin)) { + finalPlan = + addToTopForOuterJoin( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + if (foundPlans.size() != multiJoin.getNumJoinFactors()) { + finalPlan = + addToTop( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + assert foundPlans.get(foundPlans.size() - 1).size() == 1; + finalPlan = new ArrayList<>(foundPlans.get(foundPlans.size() - 1).values()).get(0); + } + } + + final List<String> fieldNames = multiJoin.getMultiJoinRel().getRowType().getFieldNames(); + return creatToProject(relBuilder, multiJoin, finalPlan, fieldNames); + } + + private static List<Map<Set<Integer>, JoinPlan>> reOrderInnerJoin( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>(); + + // First, we put each input in MultiJoin into level 0. + Map<Set<Integer>, JoinPlan> joinPlanMap = new LinkedHashMap<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!multiJoin.isNullGenerating(i)) { + HashSet<Integer> set1 = new HashSet<>(); + LinkedHashSet<Integer> set2 = new LinkedHashSet<>(); + set1.add(i); + set2.add(i); + RelNode joinFactor = multiJoin.getJoinFactor(i); + RelOptCost cost = mq.getCumulativeCost(joinFactor); + joinPlanMap.put( + set1, + new JoinPlan( + set2, + joinFactor, + new FlinkCost( + cost.getRows(), cost.getCpu(), cost.getIo(), 0.0, 0.0))); + } + } + foundPlans.add(joinPlanMap); + + // Build plans for next levels until the last level has only one plan. 
This plan contains + // all inputs that can be joined, so there's no need to continue + while (foundPlans.size() < multiJoin.getNumJoinFactors()) { + Map<Set<Integer>, JoinPlan> levelPlan = + searchLevel(mq, relBuilder, new ArrayList<>(foundPlans), multiJoin, false); + if (levelPlan.size() == 0) { + break; + } + foundPlans.add(levelPlan); + } + + return foundPlans; + } + + private static boolean canOuterJoin(LoptMultiJoin multiJoin) { + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (multiJoin.getOuterJoinCond(i) != null + && RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)).size() != 0) { + return true; + } + } + return false; + } + + private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlan) { + JoinPlan bestPlan = null; + for (Map.Entry<Set<Integer>, JoinPlan> entry : levelPlan.entrySet()) { + if (bestPlan == null || entry.getValue().cost.isLt(bestPlan.cost)) { + bestPlan = entry.getValue(); + } + } + + return bestPlan; + } + + private static JoinPlan addToTopForOuterJoin( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + RelNode rightNode = multiJoin.getJoinFactor(index); + + // make new join condition + Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConds = + getConditionsAndJoinType( + bestPlan.itemIds, Collections.singleton(index), multiJoin, true); + + if (!joinConds.isPresent()) { + // join type is always left. 
+ leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, rexBuilder.makeLiteral(true)) + .build(); + } else { + Set<RexCall> conditions = joinConds.get().f0; + List<RexNode> rexCalls = new ArrayList<>(conditions); + Set<RexCall> newCondition = + convertToNewCondition( + new ArrayList<>(set), + Collections.singletonList(index), + rexCalls, + multiJoin); + // all given left join. + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, newCondition) + .build(); + } + set.add(index); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static JoinPlan addToTop( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + set.add(index); + RelNode rightNode = multiJoin.getJoinFactor(index); + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join( + multiJoin.getMultiJoinRel().getJoinTypes().get(index), + rexBuilder.makeLiteral(true)) + .build(); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static RelNode creatToProject( + RelBuilder relBuilder, + LoptMultiJoin multiJoin, + JoinPlan finalPlan, + List<String> fieldNames) { + List<RexNode> newProjExprs = new ArrayList<>(); + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + + List<Integer> newJoinOrder = new ArrayList<>(finalPlan.itemIds); + int nJoinFactors = multiJoin.getNumJoinFactors(); + List<RelDataTypeField> fields = multiJoin.getMultiJoinFields(); + + // create a mapping from each factor to 
its field offset in the join + // ordering + final Map<Integer, Integer> factorToOffsetMap = new HashMap<>(); + for (int pos = 0, fieldStart = 0; pos < nJoinFactors; pos++) { + factorToOffsetMap.put(newJoinOrder.get(pos), fieldStart); + fieldStart += multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(pos)); + } + + for (int currFactor = 0; currFactor < nJoinFactors; currFactor++) { + // if the factor is the right factor in a removable self-join, + // then where possible, remap references to the right factor to + // the corresponding reference in the left factor + Integer leftFactor = null; + if (multiJoin.isRightFactorInRemovableSelfJoin(currFactor)) { + leftFactor = multiJoin.getOtherSelfJoinFactor(currFactor); + } + for (int fieldPos = 0; + fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor); + fieldPos++) { + int newOffset = + requireNonNull( + factorToOffsetMap.get(currFactor), + () -> "factorToOffsetMap.get(currFactor)") + + fieldPos; + if (leftFactor != null) { + Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos); + if (leftOffset != null) { + newOffset = + requireNonNull( + factorToOffsetMap.get(leftFactor), + "factorToOffsetMap.get(leftFactor)") + + leftOffset; + } + } + newProjExprs.add( + rexBuilder.makeInputRef( + fields.get(newProjExprs.size()).getType(), newOffset)); + } + } + + relBuilder.push(finalPlan.relNode); + relBuilder.project(newProjExprs, fieldNames); + + // Place the post-join filter (if it exists) on top of the final + // projection. 
+ RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter(); + if (postJoinFilter != null) { + relBuilder.filter(postJoinFilter); + } + return relBuilder.build(); + } + + private static Map<Set<Integer>, JoinPlan> searchLevel( + RelMetadataQuery mq, + RelBuilder relBuilder, + List<Map<Set<Integer>, JoinPlan>> existingLevels, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + Map<Set<Integer>, List<JoinPlan>> printNextLevel = new LinkedHashMap<>(); + Map<Set<Integer>, JoinPlan> nextLevel = new LinkedHashMap<>(); + int k = 0; + int lev = existingLevels.size() - 1; + while (k <= lev - k) { + ArrayList<JoinPlan> oneSideCandidates = new ArrayList<>(existingLevels.get(k).values()); + int oneSideSize = oneSideCandidates.size(); + for (int i = 0; i < oneSideSize; i++) { + JoinPlan oneSidePlan = oneSideCandidates.get(i); + ArrayList<JoinPlan> otherSideCandidates; + if (k == lev - k) { + otherSideCandidates = new ArrayList<>(oneSideCandidates); + if (i > 0) { + otherSideCandidates.subList(0, i).clear(); + } + } else { + otherSideCandidates = new ArrayList<>(existingLevels.get(lev - k).values()); + } + for (JoinPlan otherSidePlan : otherSideCandidates) { + Optional<JoinPlan> newJoinPlan = + buildJoin( + mq, + relBuilder, + oneSidePlan, + otherSidePlan, + multiJoin, + isOuterJoin); + if (newJoinPlan.isPresent()) { + JoinPlan existingPlan = nextLevel.get(newJoinPlan.get().itemIds); + // check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. 
+ if (existingPlan == null || newJoinPlan.get().betterThan(existingPlan)) { + nextLevel.put(newJoinPlan.get().itemIds, newJoinPlan.get()); + } + + if (printNextLevel.get(newJoinPlan.get().itemIds) == null) { + printNextLevel.put( + newJoinPlan.get().itemIds, + Collections.singletonList(newJoinPlan.get())); + } else { + List<JoinPlan> joinPlans = + new ArrayList<>(printNextLevel.get(newJoinPlan.get().itemIds)); + joinPlans.add(newJoinPlan.get()); + printNextLevel.put(newJoinPlan.get().itemIds, joinPlans); + } + } + } + } + k += 1; + } + + // print + for (Map.Entry<Set<Integer>, List<JoinPlan>> entry : printNextLevel.entrySet()) { + System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++"); Review Comment: the `print` code can be removed ########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkBusyJoinReorderRule.java: ########## @@ -0,0 +1,688 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.config.OptimizerConfigOptions; +import org.apache.flink.table.planner.plan.cost.FlinkCost; +import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; + +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.LoptMultiJoin; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * Flink busy join reorder rule, which will convert {@link MultiJoin} to a busy join tree. 
+ * + * <p>In this busy join reorder strategy, we will first try to reorder all the inner join type + * inputs in the multiJoin, and then add all outer join type inputs on the top. + * + * <p>First, reordering all the inner join type inputs in the multiJoin. We adopt the concept of + * level in dynamic programming, and the latter layer will use the results stored in the previous + * layers. First, we put all inputs (each input in {@link MultiJoin}) into level 0, then we build + * all two-inputs join at level 1 based on the {@link FlinkCost} of level 0, then we will build + * three-inputs join based on the previous two levels, then four-inputs joins ... etc, util we + * reorder all the inner join type inputs in the multiJoin. When building m-inputs join, we only + * keep the best plan (have the lowest {@link FlinkCost}) for the same set of m inputs. E.g., for + * three-inputs join, we keep only the best plan for inputs {A, B, C} among plans (A J B) J C, (A J + * C) J B, (B J C) J A. + * + * <p>Second, we will add all outer join type inputs in the MultiJoin on the top. + */ +public class FlinkBusyJoinReorderRule extends RelRule<FlinkBusyJoinReorderRule.Config> + implements TransformationRule { + + public static final LoptOptimizeJoinRule MULTI_JOIN_OPTIMIZE = + LoptOptimizeJoinRule.Config.DEFAULT.toRule(); + + /** Creates an SparkJoinReorderRule. 
*/ + protected FlinkBusyJoinReorderRule(Config config) { + super(config); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule(RelBuilderFactory relBuilderFactory) { + this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class)); + } + + @Deprecated // to be removed before 2.0 + public FlinkBusyJoinReorderRule( + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + RelFactories.FilterFactory filterFactory) { + this(RelBuilder.proto(joinFactory, projectFactory, filterFactory)); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final MultiJoin multiJoinRel = call.rel(0); + Boolean enableBusyJoin = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER); + Integer dpThreshold = + ShortcutUtils.unwrapContext(multiJoinRel) + .getTableConfig() + .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSY_JOIN_REORDER_DP_THRESHOLD); + final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel); + int numJoinFactors = multiJoin.getNumJoinFactors(); + + if (enableBusyJoin && numJoinFactors <= dpThreshold) { + RelNode bestOrder = findBestOrder(call.getMetadataQuery(), relBuilder, multiJoin); + call.transformTo(bestOrder); + } else { + MULTI_JOIN_OPTIMIZE.onMatch(call); + } + } + + private static RelNode findBestOrder( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + // In our busy join reorder strategy, we will first try to reorder all the inner join type + // inputs in the multiJoin, and then add all outer join type inputs on the top. + // First, reorder all the inner join type inputs in the multiJoin. + List<Map<Set<Integer>, JoinPlan>> foundPlans = reOrderInnerJoin(mq, relBuilder, multiJoin); + + JoinPlan finalPlan; + // Second, add all outer join type inputs in the multiJoin on the top. 
+ if (canOuterJoin(multiJoin)) { + finalPlan = + addToTopForOuterJoin( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + if (foundPlans.size() != multiJoin.getNumJoinFactors()) { + finalPlan = + addToTop( + getBestPlan(foundPlans.get(foundPlans.size() - 1)), + multiJoin, + relBuilder); + } else { + assert foundPlans.get(foundPlans.size() - 1).size() == 1; + finalPlan = new ArrayList<>(foundPlans.get(foundPlans.size() - 1).values()).get(0); + } + } + + final List<String> fieldNames = multiJoin.getMultiJoinRel().getRowType().getFieldNames(); + return creatToProject(relBuilder, multiJoin, finalPlan, fieldNames); + } + + private static List<Map<Set<Integer>, JoinPlan>> reOrderInnerJoin( + RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin) { + List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>(); + + // First, we put each input in MultiJoin into level 0. + Map<Set<Integer>, JoinPlan> joinPlanMap = new LinkedHashMap<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!multiJoin.isNullGenerating(i)) { + HashSet<Integer> set1 = new HashSet<>(); + LinkedHashSet<Integer> set2 = new LinkedHashSet<>(); + set1.add(i); + set2.add(i); + RelNode joinFactor = multiJoin.getJoinFactor(i); + RelOptCost cost = mq.getCumulativeCost(joinFactor); + joinPlanMap.put( + set1, + new JoinPlan( + set2, + joinFactor, + new FlinkCost( + cost.getRows(), cost.getCpu(), cost.getIo(), 0.0, 0.0))); + } + } + foundPlans.add(joinPlanMap); + + // Build plans for next levels until the last level has only one plan. 
This plan contains + // all inputs that can be joined, so there's no need to continue + while (foundPlans.size() < multiJoin.getNumJoinFactors()) { + Map<Set<Integer>, JoinPlan> levelPlan = + searchLevel(mq, relBuilder, new ArrayList<>(foundPlans), multiJoin, false); + if (levelPlan.size() == 0) { + break; + } + foundPlans.add(levelPlan); + } + + return foundPlans; + } + + private static boolean canOuterJoin(LoptMultiJoin multiJoin) { + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (multiJoin.getOuterJoinCond(i) != null + && RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)).size() != 0) { + return true; + } + } + return false; + } + + private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlan) { + JoinPlan bestPlan = null; + for (Map.Entry<Set<Integer>, JoinPlan> entry : levelPlan.entrySet()) { + if (bestPlan == null || entry.getValue().cost.isLt(bestPlan.cost)) { + bestPlan = entry.getValue(); + } + } + + return bestPlan; + } + + private static JoinPlan addToTopForOuterJoin( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + RelNode rightNode = multiJoin.getJoinFactor(index); + + // make new join condition + Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConds = + getConditionsAndJoinType( + bestPlan.itemIds, Collections.singleton(index), multiJoin, true); + + if (!joinConds.isPresent()) { + // join type is always left. 
+ leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, rexBuilder.makeLiteral(true)) + .build(); + } else { + Set<RexCall> conditions = joinConds.get().f0; + List<RexNode> rexCalls = new ArrayList<>(conditions); + Set<RexCall> newCondition = + convertToNewCondition( + new ArrayList<>(set), + Collections.singletonList(index), + rexCalls, + multiJoin); + // all given left join. + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join(JoinRelType.LEFT, newCondition) + .build(); + } + set.add(index); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static JoinPlan addToTop( + JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + List<Integer> remainIndexes = new ArrayList<>(); + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + if (!bestPlan.itemIds.contains(i)) { + remainIndexes.add(i); + } + } + + RelNode leftNode = bestPlan.relNode; + LinkedHashSet<Integer> set = new LinkedHashSet<>(bestPlan.itemIds); + for (int index : remainIndexes) { + set.add(index); + RelNode rightNode = multiJoin.getJoinFactor(index); + leftNode = + relBuilder + .push(leftNode) + .push(rightNode) + .join( + multiJoin.getMultiJoinRel().getJoinTypes().get(index), + rexBuilder.makeLiteral(true)) + .build(); + } + return new JoinPlan(set, leftNode, new FlinkCost(0.0, 0.0, 0.0, 0.0, 0.0)); + } + + private static RelNode creatToProject( + RelBuilder relBuilder, + LoptMultiJoin multiJoin, + JoinPlan finalPlan, + List<String> fieldNames) { + List<RexNode> newProjExprs = new ArrayList<>(); + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + + List<Integer> newJoinOrder = new ArrayList<>(finalPlan.itemIds); + int nJoinFactors = multiJoin.getNumJoinFactors(); + List<RelDataTypeField> fields = multiJoin.getMultiJoinFields(); + + // create a mapping from each factor to 
its field offset in the join + // ordering + final Map<Integer, Integer> factorToOffsetMap = new HashMap<>(); + for (int pos = 0, fieldStart = 0; pos < nJoinFactors; pos++) { + factorToOffsetMap.put(newJoinOrder.get(pos), fieldStart); + fieldStart += multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(pos)); + } + + for (int currFactor = 0; currFactor < nJoinFactors; currFactor++) { + // if the factor is the right factor in a removable self-join, + // then where possible, remap references to the right factor to + // the corresponding reference in the left factor + Integer leftFactor = null; + if (multiJoin.isRightFactorInRemovableSelfJoin(currFactor)) { + leftFactor = multiJoin.getOtherSelfJoinFactor(currFactor); + } + for (int fieldPos = 0; + fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor); + fieldPos++) { + int newOffset = + requireNonNull( + factorToOffsetMap.get(currFactor), + () -> "factorToOffsetMap.get(currFactor)") + + fieldPos; + if (leftFactor != null) { + Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos); + if (leftOffset != null) { + newOffset = + requireNonNull( + factorToOffsetMap.get(leftFactor), + "factorToOffsetMap.get(leftFactor)") + + leftOffset; + } + } + newProjExprs.add( + rexBuilder.makeInputRef( + fields.get(newProjExprs.size()).getType(), newOffset)); + } + } + + relBuilder.push(finalPlan.relNode); + relBuilder.project(newProjExprs, fieldNames); + + // Place the post-join filter (if it exists) on top of the final + // projection. 
+ RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter(); + if (postJoinFilter != null) { + relBuilder.filter(postJoinFilter); + } + return relBuilder.build(); + } + + private static Map<Set<Integer>, JoinPlan> searchLevel( + RelMetadataQuery mq, + RelBuilder relBuilder, + List<Map<Set<Integer>, JoinPlan>> existingLevels, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + Map<Set<Integer>, List<JoinPlan>> printNextLevel = new LinkedHashMap<>(); + Map<Set<Integer>, JoinPlan> nextLevel = new LinkedHashMap<>(); + int k = 0; + int lev = existingLevels.size() - 1; + while (k <= lev - k) { + ArrayList<JoinPlan> oneSideCandidates = new ArrayList<>(existingLevels.get(k).values()); + int oneSideSize = oneSideCandidates.size(); + for (int i = 0; i < oneSideSize; i++) { + JoinPlan oneSidePlan = oneSideCandidates.get(i); + ArrayList<JoinPlan> otherSideCandidates; + if (k == lev - k) { + otherSideCandidates = new ArrayList<>(oneSideCandidates); + if (i > 0) { + otherSideCandidates.subList(0, i).clear(); + } + } else { + otherSideCandidates = new ArrayList<>(existingLevels.get(lev - k).values()); + } + for (JoinPlan otherSidePlan : otherSideCandidates) { + Optional<JoinPlan> newJoinPlan = + buildJoin( + mq, + relBuilder, + oneSidePlan, + otherSidePlan, + multiJoin, + isOuterJoin); + if (newJoinPlan.isPresent()) { + JoinPlan existingPlan = nextLevel.get(newJoinPlan.get().itemIds); + // check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. 
+ if (existingPlan == null || newJoinPlan.get().betterThan(existingPlan)) { + nextLevel.put(newJoinPlan.get().itemIds, newJoinPlan.get()); + } + + if (printNextLevel.get(newJoinPlan.get().itemIds) == null) { + printNextLevel.put( + newJoinPlan.get().itemIds, + Collections.singletonList(newJoinPlan.get())); + } else { + List<JoinPlan> joinPlans = + new ArrayList<>(printNextLevel.get(newJoinPlan.get().itemIds)); + joinPlans.add(newJoinPlan.get()); + printNextLevel.put(newJoinPlan.get().itemIds, joinPlans); + } + } + } + } + k += 1; + } + + // print + for (Map.Entry<Set<Integer>, List<JoinPlan>> entry : printNextLevel.entrySet()) { + System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++"); + System.out.printf("item sets: %s%n", entry.getKey()); + for (JoinPlan joinPlan : entry.getValue()) { + System.out.println("--------------------------------------------"); + System.out.printf("costs: %s%n", joinPlan.cost.getRows()); + System.out.println( + FlinkRelOptUtil.toString( + joinPlan.relNode, + SqlExplainLevel.ALL_ATTRIBUTES, + false, + false, + false, + false, + false)); + System.out.println("--------------------------------------------"); + } + System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++"); + } + return nextLevel; + } + + private static Optional<JoinPlan> buildJoin( + RelMetadataQuery mq, + RelBuilder relBuilder, + JoinPlan oneSidePlan, + JoinPlan otherSidePlan, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + // intersect, should not join two overlapping item sets. 
+ Set<Integer> resSet = new HashSet<>(oneSidePlan.itemIds); + resSet.retainAll(otherSidePlan.itemIds); + if (!resSet.isEmpty()) { + return Optional.empty(); + } + + Optional<Tuple2<Set<RexCall>, JoinRelType>> joinConds = + getConditionsAndJoinType( + oneSidePlan.itemIds, otherSidePlan.itemIds, multiJoin, isOuterJoin); + if (!joinConds.isPresent()) { + return Optional.empty(); + } + + Set<RexCall> conditions = joinConds.get().f0; + JoinRelType joinType = joinConds.get().f1; + + LinkedHashSet<Integer> newItemIds = new LinkedHashSet<>(); + JoinPlan leftPlan; + JoinPlan rightPlan; + // put the deeper side on the left, tend to build a left-deep tree. + if (oneSidePlan.itemIds.size() >= otherSidePlan.itemIds.size()) { + leftPlan = oneSidePlan; + rightPlan = otherSidePlan; + } else { + leftPlan = otherSidePlan; + rightPlan = oneSidePlan; + if (isOuterJoin) { + joinType = (joinType == JoinRelType.LEFT) ? JoinRelType.RIGHT : JoinRelType.LEFT; + } + } + newItemIds.addAll(leftPlan.itemIds); + newItemIds.addAll(rightPlan.itemIds); + + List<RexNode> rexCalls = new ArrayList<>(conditions); + Set<RexCall> newCondition = + convertToNewCondition( + new ArrayList<>(leftPlan.itemIds), + new ArrayList<>(rightPlan.itemIds), + rexCalls, + multiJoin); + + Join newJoin = + (Join) + relBuilder + .push(leftPlan.relNode) + .push(rightPlan.relNode) + .join(joinType, newCondition) + .build(); + + RelOptCost cost = mq.getCumulativeCost(newJoin); + return Optional.of( + new JoinPlan( + newItemIds, + newJoin, + new FlinkCost(cost.getRows(), cost.getCpu(), cost.getIo(), 0.0, 0.0))); + } + + private static Set<RexCall> convertToNewCondition( + List<Integer> leftItemIds, + List<Integer> rightItemIds, + List<RexNode> rexNodes, + LoptMultiJoin multiJoin) { + RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder(); + Set<RexCall> newCondition = new HashSet<>(); + for (RexNode cond : rexNodes) { + RexCall rexCond = (RexCall) cond; + List<RexNode> resultRexNode = new 
ArrayList<>(); + for (RexNode rexNode : rexCond.getOperands()) { + rexNode = + rexNode.accept( + new RexInputConverterForBusyJoin( + rexBuilder, multiJoin, leftItemIds, rightItemIds)); + resultRexNode.add(rexNode); + } + RexNode resultRex = rexBuilder.makeCall(rexCond.op, resultRexNode); + newCondition.add((RexCall) resultRex); + } + + return newCondition; + } + + private static Optional<Tuple2<Set<RexCall>, JoinRelType>> getConditionsAndJoinType( + Set<Integer> oneItemIds, + Set<Integer> otherItemIds, + LoptMultiJoin multiJoin, + boolean isOuterJoin) { + if (oneItemIds.size() + otherItemIds.size() < 2) { + return Optional.empty(); + } + JoinRelType joinType = JoinRelType.INNER; + if (multiJoin.getMultiJoinRel().isFullOuterJoin()) { + assert multiJoin.getNumJoinFactors() == 2; + joinType = JoinRelType.FULL; + } + + Set<RexCall> resultRexCall = new HashSet<>(); + List<RexNode> joinConditions = new ArrayList<>(); + if (isOuterJoin) { + for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) { + joinConditions.addAll(RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i))); + } + } else { + joinConditions = multiJoin.getJoinFilters(); + } + + for (RexNode joinCond : joinConditions) { + if (joinCond instanceof RexCall) { + RexCall callCondition = (RexCall) joinCond; + ImmutableBitSet factorsRefByJoinFilter = + multiJoin.getFactorsRefByJoinFilter(callCondition); + int oneItemNumbers = 0; + int otherItemNumbers = 0; + for (int oneItemId : oneItemIds) { + if (factorsRefByJoinFilter.get(oneItemId)) { + oneItemNumbers++; + if (isOuterJoin && multiJoin.isNullGenerating(oneItemId)) { + joinType = JoinRelType.RIGHT; + } + } + } + for (int otherItemId : otherItemIds) { + if (factorsRefByJoinFilter.get(otherItemId)) { + otherItemNumbers++; + if (isOuterJoin && multiJoin.isNullGenerating(otherItemId)) { + joinType = JoinRelType.LEFT; + } + } + } + + if (oneItemNumbers > 0 + && otherItemNumbers > 0 + && oneItemNumbers + otherItemNumbers + == factorsRefByJoinFilter.asSet().size()) { 
+ resultRexCall.add(callCondition); + } + } else { + return Optional.empty(); + } + } + + if (resultRexCall.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(Tuple2.of(resultRexCall, joinType)); + } + } + + // ~ Inner Classes ---------------------------------------------------------- + private static class JoinPlan { + final LinkedHashSet<Integer> itemIds; Review Comment: What does `itemId` mean? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org