wanglijie95 commented on code in PR #21162:
URL: https://github.com/apache/flink/pull/21162#discussion_r1058767679


##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobEdge;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Util to compute {@link VertexInputInfo}s for execution job vertex. */
+public class VertexInputInfoComputationUtils {
+    /**
+     * Computes the {@link VertexInputInfo} of every input of a job vertex.
+     *
+     * <p>Each input {@link JobEdge} of the vertex is mapped to its source
+     * data set id, resolved to an {@link IntermediateResult} through the
+     * given retriever and wrapped as an {@link IntermediateResultInfo}. The
+     * computation itself is delegated to the overload that takes the
+     * parallelism, the list of inputs and the dynamic-graph flag, which are
+     * read from the vertex and its graph.
+     *
+     * @param ejv the consuming vertex; {@code checkState} rejects a vertex
+     *     whose parallelism is not decided yet
+     * @param intermediateResultRetriever looks up the intermediate result
+     *     for a data set id; a {@code null} result is rejected by
+     *     {@code checkNotNull} when the result is wrapped
+     * @return the computed input infos, keyed by the id of the consumed
+     *     intermediate data set
+     */
+    public static Map<IntermediateDataSetID, VertexInputInfo> 
computeVertexInputInfos(
+            ExecutionJobVertex ejv,
+            Function<IntermediateDataSetID, IntermediateResult> 
intermediateResultRetriever) {
+        checkState(ejv.isParallelismDecided());
+        final List<IntermediateResultInfo> intermediateResultInfos =
+                ejv.getJobVertex().getInputs().stream()
+                        .map(JobEdge::getSourceId)
+                        .map(intermediateResultRetriever)
+                        .map(IntermediateResultWrapper::new)
+                        .collect(Collectors.toList());
+        return computeVertexInputInfos(
+                ejv.getParallelism(), intermediateResultInfos, 
ejv.getGraph().isDynamic());
+    }
+    /**
+     * Computes one {@link VertexInputInfo} per consumed intermediate result.
+     *
+     * <p>Inputs that report {@code isPointwise()} are handled by
+     * {@link #computeVertexInputInfoForPointwise}, all other inputs by
+     * {@link #computeVertexInputInfoForAllToAll}. The returned map is a
+     * {@link LinkedHashMap}, so its entries keep the order of
+     * {@code inputs}. Entries are added with {@code putIfAbsent}: if several
+     * inputs share a result id, only the info of the first one is kept.
+     *
+     * @param parallelism the parallelism of the consuming vertex; must be
+     *     greater than zero
+     * @param inputs the consumed intermediate results
+     * @param isDynamicGraph whether the graph is dynamic
+     * @return the computed input infos, keyed by the id of the consumed
+     *     intermediate data set
+     */
+    public static Map<IntermediateDataSetID, VertexInputInfo> 
computeVertexInputInfos(
+            int parallelism, List<IntermediateResultInfo> inputs, boolean 
isDynamicGraph) {
+
+        checkArgument(parallelism > 0);
+        final Map<IntermediateDataSetID, VertexInputInfo> vertexInputInfos = 
new LinkedHashMap<>();
+
+        for (IntermediateResultInfo input : inputs) {
+            int sourceParallelism = input.getNumPartitions();
+
+            if (input.isPointwise()) {
+                vertexInputInfos.putIfAbsent(
+                        input.getResultId(),
+                        computeVertexInputInfoForPointwise(
+                                sourceParallelism,
+                                parallelism,
+                                input::getNumSubpartitions,
+                                isDynamicGraph));
+            } else {
+                vertexInputInfos.putIfAbsent(
+                        input.getResultId(),
+                        computeVertexInputInfoForAllToAll(
+                                sourceParallelism,
+                                parallelism,
+                                input::getNumSubpartitions,
+                                isDynamicGraph,
+                                input.isBroadcast()));
+            }
+        }
+
+        return vertexInputInfos;
+    }
+
+    /**
+     * Compute the {@link VertexInputInfo} for a
+     * {@link DistributionPattern#POINTWISE} edge.
+     *
+     * <p>If {@code sourceCount >= targetCount}, downstream subtask
+     * {@code index} consumes the consecutive partitions from
+     * {@code index * sourceCount / targetCount} (inclusive) to
+     * {@code (index + 1) * sourceCount / targetCount} (exclusive). Its
+     * subpartition range is computed as if it were the only consumer, using
+     * the subpartition count of the first partition of that range.
+     *
+     * <p>Otherwise every partition is consumed by a consecutive group of
+     * downstream subtasks, and the subpartitions of that partition are
+     * distributed among the subtasks of the group. The group boundaries are
+     * obtained with a rounded-up integer division, so the groups cover all
+     * downstream subtasks without overlapping.
+     *
+     * <p>In both cases the edge is treated as non-broadcast when the
+     * subpartition range is computed.
+     *
+     * @param sourceCount the parallelism of upstream
+     * @param targetCount the parallelism of downstream
+     * @param numOfSubpartitionsRetriever a retriever to get the number of
+     *     subpartitions of the partition with the given index
+     * @param isDynamicGraph whether the graph is dynamic
+     * @return the computed {@link VertexInputInfo}
+     */
+    @VisibleForTesting
+    static VertexInputInfo computeVertexInputInfoForPointwise(
+            int sourceCount,
+            int targetCount,
+            Function<Integer, Integer> numOfSubpartitionsRetriever,
+            boolean isDynamicGraph) {
+
+        final List<TaskInputInfo> taskInputInfos = new ArrayList<>();
+
+        if (sourceCount >= targetCount) {
+            for (int index = 0; index < targetCount; index++) {
+
+                int start = index * sourceCount / targetCount;
+                int end = (index + 1) * sourceCount / targetCount;
+
+                PartitionIndexRange partitionRange = new 
PartitionIndexRange(start, end - 1);
+                SubpartitionIndexRange subpartitionRange =
+                        computeConsumedSubpartitionRange(
+                                index,
+                                1,
+                                numOfSubpartitionsRetriever.apply(start),
+                                isDynamicGraph,
+                                false);
+                taskInputInfos.add(new TaskInputInfo(index, partitionRange, 
subpartitionRange));
+            }
+        } else {
+            for (int partitionNum = 0; partitionNum < sourceCount; 
partitionNum++) {
+
+                int start = (partitionNum * targetCount + sourceCount - 1) / 
sourceCount;
+                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) 
/ sourceCount;
+                int numConsumers = end - start;
+
+                for (int i = start; i < end; i++) {
+                    PartitionIndexRange partitionRange =
+                            new PartitionIndexRange(partitionNum, 
partitionNum);
+                    SubpartitionIndexRange subpartitionRange =
+                            computeConsumedSubpartitionRange(
+                                    i,
+                                    numConsumers,
+                                    
numOfSubpartitionsRetriever.apply(partitionNum),
+                                    isDynamicGraph,
+                                    false);
+                    taskInputInfos.add(new TaskInputInfo(i, partitionRange, 
subpartitionRange));
+                }
+            }
+        }
+        return new VertexInputInfo(taskInputInfos);
+    }
+
+    /**
+     * Compute the {@link VertexInputInfo} for a
+     * {@link DistributionPattern#ALL_TO_ALL} edge.
+     *
+     * <p>Every downstream subtask consumes all partitions, i.e. the
+     * partition range {@code [0, sourceCount - 1]}. The subpartition range
+     * of each subtask is computed by
+     * {@link #computeConsumedSubpartitionRange} with {@code targetCount}
+     * consumers.
+     *
+     * <p>NOTE(review): only the subpartition count of partition 0 is
+     * queried; presumably all partitions of the result have the same number
+     * of subpartitions - confirm against the callers.
+     *
+     * @param sourceCount the parallelism of upstream
+     * @param targetCount the parallelism of downstream
+     * @param numOfSubpartitionsRetriever a retriever to get the number of
+     *     subpartitions of the partition with the given index
+     * @param isDynamicGraph whether the graph is dynamic
+     * @param isBroadcast whether the edge is broadcast
+     * @return the computed {@link VertexInputInfo}
+     */
+    @VisibleForTesting
+    static VertexInputInfo computeVertexInputInfoForAllToAll(
+            int sourceCount,
+            int targetCount,
+            Function<Integer, Integer> numOfSubpartitionsRetriever,
+            boolean isDynamicGraph,
+            boolean isBroadcast) {
+        final List<TaskInputInfo> taskInputInfos = new ArrayList<>();
+        for (int i = 0; i < targetCount; ++i) {
+            PartitionIndexRange partitionRange = new PartitionIndexRange(0, 
sourceCount - 1);
+            SubpartitionIndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            i,
+                            targetCount,
+                            numOfSubpartitionsRetriever.apply(0),
+                            isDynamicGraph,
+                            isBroadcast);
+            taskInputInfos.add(new TaskInputInfo(i, partitionRange, 
subpartitionRange));
+        }
+        return new VertexInputInfo(taskInputInfos);
+    }
+
+    /**
+     * Compute the consumed subpartition range for a subtask.
+     *
+     * <p>The subtask index is first reduced modulo {@code numConsumers} to
+     * a consumer index. Three cases follow:
+     *
+     * <ul>
+     *   <li>Non-dynamic graph: the range contains exactly the subpartition
+     *       whose index equals the consumer index;
+     *       {@code numSubpartitions} is not used.
+     *   <li>Dynamic graph, broadcast edge: the range is {@code [0, 0]} and
+     *       {@code numSubpartitions} must be 1.
+     *   <li>Dynamic graph, non-broadcast edge: the subpartitions are
+     *       distributed evenly, the consumer gets the range from
+     *       {@code consumerIndex * numSubpartitions / numConsumers}
+     *       (inclusive) to
+     *       {@code (consumerIndex + 1) * numSubpartitions / numConsumers}
+     *       (exclusive). Different consumers get roughly the same number
+     *       of subpartitions.
+     * </ul>
+     *
+     * @param consumerSubtaskIndex the subtask index
+     * @param numConsumers the total number of consumers
+     * @param numSubpartitions the total number of subpartitions
+     * @param isDynamicGraph whether the graph is dynamic
+     * @param isBroadcast whether the edge is broadcast
+     * @return the computed subpartition range
+     * @throws IllegalArgumentException for a dynamic graph, if a broadcast
+     *     edge has a subpartition count other than 1, or if a
+     *     non-broadcast edge has more consumers than subpartitions
+     */
+    @VisibleForTesting
+    static SubpartitionIndexRange computeConsumedSubpartitionRange(
+            int consumerSubtaskIndex,
+            int numConsumers,
+            int numSubpartitions,
+            boolean isDynamicGraph,
+            boolean isBroadcast) {
+        int consumerIndex = consumerSubtaskIndex % numConsumers;
+        if (!isDynamicGraph) {
+            return new SubpartitionIndexRange(consumerIndex, consumerIndex);
+        } else {
+            if (isBroadcast) {
+                // broadcast results have only one subpartition, and are 
consumed multiple times.
+                checkArgument(numSubpartitions == 1);
+                return new SubpartitionIndexRange(0, 0);
+            } else {
+                checkArgument(consumerIndex < numConsumers);
+                checkArgument(numConsumers <= numSubpartitions);
+
+                int start = consumerIndex * numSubpartitions / numConsumers;
+                int nextStart = (consumerIndex + 1) * numSubpartitions / 
numConsumers;
+
+                return new SubpartitionIndexRange(start, nextStart - 1);
+            }
+        }
+    }
+
+    private static class IntermediateResultWrapper implements 
IntermediateResultInfo {
+        private final IntermediateResult intermediateResult;
+
+        IntermediateResultWrapper(IntermediateResult intermediateResult) {
+            this.intermediateResult = checkNotNull(intermediateResult);
+        }
+
+        @Override
+        public IntermediateDataSetID getResultId() {
+            return intermediateResult.getId();
+        }
+
+        @Override
+        public boolean isBroadcast() {
+            return intermediateResult.isBroadcast();
+        }
+
+        @Override
+        public boolean isPointwise() {
+            return intermediateResult.getConsumingDistributionPattern()
+                    == DistributionPattern.POINTWISE;
+        }
+
+        @Override
+        public int getNumPartitions() {
+            return intermediateResult.getNumberOfAssignedPartitions();
+        }
+
+        @Override
+        public int getNumSubpartitions(int partitionIndex) {
+            boolean isDynamicGraph = 
intermediateResult.getProducer().getGraph().isDynamic();
+            // Note that for non-dynamic graph, the num of subpartition has 
not been decided at

Review Comment:
   Fixed



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to