zhuzhurk commented on code in PR #25414: URL: https://github.com/apache/flink/pull/25414#discussion_r1843381297
########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/util/ImmutableStreamNode.java: ########## @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph.util; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Helper class that provides read-only StreamNode. 
*/ +@Internal +public class ImmutableStreamNode { + private final StreamNode streamNode; + private List<ImmutableStreamEdge> immutableOutEdges = null; + private List<ImmutableStreamEdge> immutableInEdges = null; + + public ImmutableStreamNode(StreamNode streamNode) { + this.streamNode = streamNode; + } + + public List<ImmutableStreamEdge> getOutEdges() { + if (immutableOutEdges == null) { + immutableOutEdges = new ArrayList<>(); + for (StreamEdge edge : streamNode.getOutEdges()) { + immutableOutEdges.add(new ImmutableStreamEdge(edge)); + } + } + return Collections.unmodifiableList(immutableOutEdges); + } + + public List<ImmutableStreamEdge> getInEdges() { + if (immutableInEdges == null) { + immutableInEdges = new ArrayList<>(); + for (StreamEdge edge : streamNode.getInEdges()) { + immutableInEdges.add(new ImmutableStreamEdge(edge)); + } + } + return Collections.unmodifiableList(immutableInEdges); + } + + public int getId() { + return streamNode.getId(); + } + + public @Nullable StreamOperatorFactory<?> getOperatorFactory() { Review Comment: The returned `operatorFactory` can still be modified. In which case does Flink need to get an `operatorFactory` from an `immutableStreamNode`? ########## flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java: ########## @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamNode; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Stream node level implement for {@link ForwardGroup}. */ +public class StreamNodeForwardGroup implements ForwardGroup { + + private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; + + private final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode = + new HashMap<>(); + + // For a group of chained stream nodes, their parallelism is consistent. In order to make + // calculation and usage easier, we only use the start node to calculate forward group. 
+ public StreamNodeForwardGroup( + final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode) { + checkNotNull(chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a start node in the forward group whose maximum parallelism is smaller than the group's parallelism"); + } + + this.chainedStreamNodeGroupsByStartNode.putAll(chainedStreamNodeGroupsByStartNode); + } + + @Override + public void setParallelism(int parallelism) { + checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); + this.parallelism = parallelism; + } + + @Override + public boolean isParallelismDecided() { + return parallelism > 0; + } + + @Override + public int getParallelism() { + checkState(isParallelismDecided()); + return parallelism; + } + + @Override + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + @Override + public int getMaxParallelism() { + checkState(isMaxParallelismDecided()); + return maxParallelism; + } + + @VisibleForTesting + public int size() { + return chainedStreamNodeGroupsByStartNode.values().stream().mapToInt(List::size).sum(); + } + + public Iterable<StreamNode> getStartNodes() { + return chainedStreamNodeGroupsByStartNode.keySet(); + } + + public 
Iterable<List<StreamNode>> getChainedStreamNodeGroups() { + return chainedStreamNodeGroupsByStartNode.values(); + } + + /** + * Responds to merge targetForwardGroup into this and update the parallelism information for + * stream nodes in merged forward group. + * + * @param targetForwardGroup The forward group to be merged. + * @return whether the merge was successful. + */ + public boolean mergeForwardGroup(StreamNodeForwardGroup targetForwardGroup) { Review Comment: Maybe rename `targetForwardGroup` -> `forwardGroupToMerge`. When merging group B into group A, it's more natural to regard group A as the target. The same applies to `canTargetMergeIntoSourceForwardGroup`. ########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroupTest.java: ########## @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.StreamOperator; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Unit tests for {@link StreamNodeForwardGroup}. */ +class StreamNodeForwardGroupTest { + @Test + void testStreamNodeForwardGroup() { + Map<StreamNode, List<StreamNode>> chainedNodeGroupsByStartNode = new HashMap<>(); + StreamNode streamNode1 = createStreamNode(0, 1, 1); + StreamNode streamNode2 = createStreamNode(1, 1, 1); + + chainedNodeGroupsByStartNode.put(streamNode1, List.of(streamNode1)); + chainedNodeGroupsByStartNode.put(streamNode2, List.of(streamNode2)); + + StreamNodeForwardGroup forwardGroup = + new StreamNodeForwardGroup(chainedNodeGroupsByStartNode); + assertThat(forwardGroup.getParallelism()).isEqualTo(1); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(1); + assertThat(forwardGroup.size()).isEqualTo(2); + + StreamNode streamNode3 = createStreamNode(3, 1, 1); + chainedNodeGroupsByStartNode.put(streamNode2, List.of(streamNode2, streamNode3)); + + StreamNodeForwardGroup forwardGroup2 = + new StreamNodeForwardGroup(chainedNodeGroupsByStartNode); + assertThat(forwardGroup2.size()).isEqualTo(3); + } + + @Test + void testMergeForwardGroup() { + StreamNodeForwardGroup forwardGroup = + createForwardGroupWithSingleNode(createStreamNode(0, -1, -1)); + + StreamNodeForwardGroup forwardGroupWithUnDecidedParallelism = + createForwardGroupWithSingleNode(createStreamNode(1, -1, -1)); + forwardGroup.mergeForwardGroup(forwardGroupWithUnDecidedParallelism); + assertThat(forwardGroup.isParallelismDecided()).isFalse(); + assertThat(forwardGroup.isMaxParallelismDecided()).isFalse(); + + StreamNodeForwardGroup forwardGroupWithDecidedParallelism = + 
createForwardGroupWithSingleNode(createStreamNode(2, 2, 4)); + forwardGroup.mergeForwardGroup(forwardGroupWithDecidedParallelism); + assertThat(forwardGroup.getParallelism()).isEqualTo(2); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4); + + StreamNodeForwardGroup forwardGroupWithLargerMaxParallelism = + createForwardGroupWithSingleNode(createStreamNode(3, 2, 5)); + + // The target max parallelism is larger than source. + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithLargerMaxParallelism)).isTrue(); + assertThat(forwardGroup.getMaxParallelism()).isEqualTo(4); + + StreamNodeForwardGroup forwardGroupWithSmallerMaxParallelism = + createForwardGroupWithSingleNode(createStreamNode(4, 2, 3)); + assertThat(forwardGroup.mergeForwardGroup(forwardGroupWithSmallerMaxParallelism)).isTrue(); Review Comment: Maybe add a check of the maxParallelism? ########## flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java: ########## @@ -101,6 +104,101 @@ public static Map<JobVertexID, ForwardGroup> computeForwardGroups( return ret; } + /** + * We calculate forward group by a set of chained stream nodes, and use the start node to + * identify the chain group. + * + * @param topologicallySortedChainedStreamNodesMap Topologically sorted chained stream nodes. + * @param forwardProducersRetriever Records all upstream chain groups which connected to the + * given chain group with forward edge. + * @return a map of forward groups, with the start node id as the key. + */ + public static Map<Integer, StreamNodeForwardGroup> computeStreamNodeForwardGroup( + final Map<StreamNode, List<StreamNode>> topologicallySortedChainedStreamNodesMap, + final Function<StreamNode, Set<StreamNode>> forwardProducersRetriever) { + // In the forwardProducersRetriever, only the upstream nodes connected to the given start + // node by the forward edge are saved. 
We need to calculate the chain groups that can be + // accessed with consecutive forward edges and put them in the same forward group. + final Map<StreamNode, Set<StreamNode>> nodeToGroup = new IdentityHashMap<>(); + for (StreamNode currentNode : topologicallySortedChainedStreamNodesMap.keySet()) { + Set<StreamNode> currentGroup = new HashSet<>(); + currentGroup.add(currentNode); + nodeToGroup.put(currentNode, currentGroup); + for (StreamNode producerNode : forwardProducersRetriever.apply(currentNode)) { + // Merge nodes from the current group and producer group. + final Set<StreamNode> producerGroup = nodeToGroup.get(producerNode); + // The producerGroup cannot be null unless the topological order is incorrect. + if (producerGroup == null) { + throw new IllegalStateException( + "Producer task " + + producerNode.getId() + + " forward group is null" + + " while calculating forward group for the consumer task " + + currentNode.getId() + + ". This should be a forward group building bug."); + } + // Merge the forward group groups where the upstream and downstream are connected by + // forward edge + if (currentGroup != producerGroup) { + currentGroup = + VertexGroupComputeUtil.mergeVertexGroups( + currentGroup, producerGroup, nodeToGroup); + } + } + } + final Map<Integer, StreamNodeForwardGroup> result = new HashMap<>(); + for (Set<StreamNode> nodeGroup : VertexGroupComputeUtil.uniqueVertexGroups(nodeToGroup)) { + Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode = new HashMap<>(); Review Comment: HashMap -> IdentityHashMap ########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java: ########## @@ -101,6 +132,49 @@ private void testThreeVerticesConnectSequentially( checkGroupSize(groups, numOfGroups, groupSizes); } + @Test + void testVariousConnectTypesBetweenChainedStreamNodeGroup() throws Exception { + testThreeChainedStreamNodeGroupsConnectSequentially(false, true, 2, 1, 2); + 
testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 1); Review Comment: -> testThreeChainedStreamNodeGroupsConnectSequentially(false, false, 3, 1, 1, 1); ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + 
this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; Review Comment: Could you add some descriptions of these map arguments? What do they stand for? What are they used for? Why are they reused? Are they read-only or modifiable? ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + 
this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Modification for edge {} is not allowed as the target node with id {} is in frozen list.", + targetEdge, + targetNodeId); + return false; + } + + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null + && targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) { + LOG.info( + "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.", + targetEdge); + return false; + } + + return true; + } + + private void modifyOutputPartitioner( + StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) { + if (newPartitioner == null || targetEdge == null) { + return; + } + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + + StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner(); + + targetEdge.setPartitioner(newPartitioner); + + // For non-chainable edges, we change the ForwardPartitioner to RescalePartitioner to avoid + // limiting the parallelism of the downstream node by the forward edge. + // 1. If the upstream job vertex is created. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && frozenNodeToStartNodeMap.containsKey(sourceNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 2. If the source and target are non-chainable. 
+ if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 3. If the forward group cannot be merged. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner Review Comment: Maybe use one outer `if (targetEdge.getPartitioner() instanceof ForwardPartitioner` to wrap these 3 if-clauses? ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + 
this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Modification for edge {} is not allowed as the target node with id {} is in frozen list.", Review Comment: Could this kind of log be emitted too frequently? ########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java: ########## @@ -59,6 +68,28 @@ void testIsolatedVertices() throws Exception { checkGroupSize(groups, 0); } + @Test + void testIsolatedChainedStreamNodeGroups() throws Exception { + Map<StreamNode, List<StreamNode>> topologicallySortedChainedStreamNodeByStartNode = + new LinkedHashMap<>(); + Map<StreamNode, Set<StreamNode>> forwardProducersByStartNode = Collections.emptyMap(); + for (int i = 1; i <= 3; ++i) { + StreamNode streamNode = createStreamNode(i); + topologicallySortedChainedStreamNodeByStartNode.put( + streamNode, Collections.singletonList(streamNode)); + } + + Set<ForwardGroup> groups = + computeForwardGroups( + topologicallySortedChainedStreamNodeByStartNode, + forwardProducersByStartNode); + + // Different from the job vertex forward group, the stream node forward group is allowed to + // contain only one single stream node, as these groups may merge with other groups in the + // future. + checkGroupSize(groups, 3, 1); Review Comment: By the way, will it create a job vertex forward group if the size is 1 when the job vertex is created?
########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. 
*/ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. 
+ for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Modification for edge {} is not allowed as the target node with id {} is in frozen list.", + targetEdge, + targetNodeId); + return false; + } + + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null + && targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) { + LOG.info( + "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.", + targetEdge); + return false; + } + + return true; + } + + private void modifyOutputPartitioner( + StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) { + if (newPartitioner == null || targetEdge == null) { + return; + } + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + + StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner(); + + targetEdge.setPartitioner(newPartitioner); + + // For non-chainable edges, we change the ForwardPartitioner to RescalePartitioner to avoid + // limiting the parallelism of the downstream node by the forward edge. + // 1. If the upstream job vertex is created. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && frozenNodeToStartNodeMap.containsKey(sourceNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 2. If the source and target are non-chainable. 
+ if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 3. If the forward group cannot be merged. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !mergeForwardGroups(sourceNodeId, targetNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + output.setPartitioner(targetEdge.getPartitioner()); + } + LOG.info( + "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.", + targetEdge, + oldPartitioner, + newPartitioner, + targetEdge.getPartitioner()); + } + + private boolean mergeForwardGroups(Integer sourceNodeId, Integer targetNodeId) { + StreamNodeForwardGroup sourceForwardGroup = + startAndEndNodeIdToForwardGroupMap.get(sourceNodeId); + StreamNodeForwardGroup targetForwardGroup = + startAndEndNodeIdToForwardGroupMap.get(targetNodeId); + if (sourceForwardGroup == null || targetForwardGroup == null) { + return false; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( + sourceForwardGroup, targetForwardGroup)) { + return false; + } + + // sanity check Review Comment: The "sanity check" here is a bit misleading because here is not only a sanity check but also a critical production operation. Maybe remove the above check as mentioned in another comment, and just returns false if the merge method returns false? ########## flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java: ########## @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamNode; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Stream node level implement for {@link ForwardGroup}. */ +public class StreamNodeForwardGroup implements ForwardGroup { + + private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; + + private final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode = + new HashMap<>(); + + // For a group of chained stream nodes, their parallelism is consistent. In order to make + // calculation and usage easier, we only use the start node to calculate forward group. 
+ public StreamNodeForwardGroup( + final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode) { + checkNotNull(chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a start node in the forward group whose maximum parallelism is smaller than the group's parallelism"); + } + + this.chainedStreamNodeGroupsByStartNode.putAll(chainedStreamNodeGroupsByStartNode); + } + + @Override + public void setParallelism(int parallelism) { + checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); + this.parallelism = parallelism; + } + + @Override + public boolean isParallelismDecided() { + return parallelism > 0; + } + + @Override + public int getParallelism() { + checkState(isParallelismDecided()); + return parallelism; + } + + @Override + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + @Override + public int getMaxParallelism() { + checkState(isMaxParallelismDecided()); + return maxParallelism; + } + + @VisibleForTesting + public int size() { + return chainedStreamNodeGroupsByStartNode.values().stream().mapToInt(List::size).sum(); + } + + public Iterable<StreamNode> getStartNodes() { + return chainedStreamNodeGroupsByStartNode.keySet(); + } + + public 
Iterable<List<StreamNode>> getChainedStreamNodeGroups() { + return chainedStreamNodeGroupsByStartNode.values(); + } + + /** + * Responds to merge targetForwardGroup into this and update the parallelism information for + * stream nodes in merged forward group. + * + * @param targetForwardGroup The forward group to be merged. + * @return whether the merge was successful. + */ + public boolean mergeForwardGroup(StreamNodeForwardGroup targetForwardGroup) { + checkNotNull(targetForwardGroup); + + if (targetForwardGroup == this) { + return true; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( + this, targetForwardGroup)) { + return false; + } + + this.chainedStreamNodeGroupsByStartNode.putAll( + targetForwardGroup.chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = new HashSet<>(); + if (this.isParallelismDecided()) { + configuredParallelisms.add(this.getParallelism()); + } + if (targetForwardGroup.isParallelismDecided()) { + configuredParallelisms.add(targetForwardGroup.getParallelism()); + } + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } Review Comment: It's better to simplify the check. 
``` if (!this.isParallelismDecided()) { this.parallelism = targetForwardGroup.getParallelism(); } else if (targetForwardGroup.isParallelismDecided()) { checkState(this.parallelism == targetForwardGroup.getParallelism()); } ``` ########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java: ########## @@ -59,6 +68,28 @@ void testIsolatedVertices() throws Exception { checkGroupSize(groups, 0); } + @Test + void testIsolatedChainedStreamNodeGroups() throws Exception { + Map<StreamNode, List<StreamNode>> topologicallySortedChainedStreamNodeByStartNode = + new LinkedHashMap<>(); + Map<StreamNode, Set<StreamNode>> forwardProducersByStartNode = Collections.emptyMap(); + for (int i = 1; i <= 3; ++i) { + StreamNode streamNode = createStreamNode(i); + topologicallySortedChainedStreamNodeByStartNode.put( + streamNode, Collections.singletonList(streamNode)); + } + + Set<ForwardGroup> groups = + computeForwardGroups( + topologicallySortedChainedStreamNodeByStartNode, + forwardProducersByStartNode); + + // Different from the job vertex forward group, the stream node forward group is allowed to + // contain only one single stream node, as these groups may merge with other groups in the + // future. + checkGroupSize(groups, 3, 1); Review Comment: -> checkGroupSize(groups, 3, 1, 1, 1) ########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java: ########## @@ -182,7 +319,56 @@ private static Set<ForwardGroup> computeForwardGroups(JobVertex... vertices) { private static void checkGroupSize( Set<ForwardGroup> groups, int numOfGroups, Integer... 
sizes) { assertThat(groups.size()).isEqualTo(numOfGroups); - assertThat(groups.stream().map(ForwardGroup::size).collect(Collectors.toList())) + assertThat( + groups.stream() + .map( + group -> { + if (group instanceof JobVertexForwardGroup) { + return ((JobVertexForwardGroup) group).size(); + } else { + return ((StreamNodeForwardGroup) group).size(); + } + }) + .collect(Collectors.toList())) Review Comment: Looks to me this change is not needed if invoking this method as proposed in other comments. ########## flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java: ########## @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamNode; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Stream node level implement for {@link ForwardGroup}. */ +public class StreamNodeForwardGroup implements ForwardGroup { + + private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; + + private final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode = + new HashMap<>(); + + // For a group of chained stream nodes, their parallelism is consistent. In order to make + // calculation and usage easier, we only use the start node to calculate forward group. 
+ public StreamNodeForwardGroup( + final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode) { + checkNotNull(chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a start node in the forward group whose maximum parallelism is smaller than the group's parallelism"); + } + + this.chainedStreamNodeGroupsByStartNode.putAll(chainedStreamNodeGroupsByStartNode); + } + + @Override + public void setParallelism(int parallelism) { + checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); + this.parallelism = parallelism; + } + + @Override + public boolean isParallelismDecided() { + return parallelism > 0; + } + + @Override + public int getParallelism() { + checkState(isParallelismDecided()); + return parallelism; + } + + @Override + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + @Override + public int getMaxParallelism() { + checkState(isMaxParallelismDecided()); + return maxParallelism; + } + + @VisibleForTesting + public int size() { + return chainedStreamNodeGroupsByStartNode.values().stream().mapToInt(List::size).sum(); + } + + public Iterable<StreamNode> getStartNodes() { + return chainedStreamNodeGroupsByStartNode.keySet(); + } + + public 
Iterable<List<StreamNode>> getChainedStreamNodeGroups() { + return chainedStreamNodeGroupsByStartNode.values(); + } + + /** + * Responds to merge targetForwardGroup into this and update the parallelism information for + * stream nodes in merged forward group. + * + * @param targetForwardGroup The forward group to be merged. + * @return whether the merge was successful. + */ + public boolean mergeForwardGroup(StreamNodeForwardGroup targetForwardGroup) { + checkNotNull(targetForwardGroup); + + if (targetForwardGroup == this) { + return true; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( + this, targetForwardGroup)) { + return false; + } + + this.chainedStreamNodeGroupsByStartNode.putAll( + targetForwardGroup.chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = new HashSet<>(); + if (this.isParallelismDecided()) { + configuredParallelisms.add(this.getParallelism()); + } + if (targetForwardGroup.isParallelismDecided()) { + configuredParallelisms.add(targetForwardGroup.getParallelism()); + } + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = new HashSet<>(); + if (this.isMaxParallelismDecided()) { + configuredMaxParallelisms.add(this.getMaxParallelism()); + } + if (targetForwardGroup.isMaxParallelismDecided()) { + configuredMaxParallelisms.add(targetForwardGroup.getMaxParallelism()); + } + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism); + } Review Comment: It's better to simplify it. 
``` if (targetForwardGroup.isMaxParallelismDecided() && (!this.isMaxParallelismDecided() || targetForwardGroup.getMaxParallelism() < this.maxParallelism)) { this.maxParallelism = targetForwardGroup.getMaxParallelism(); checkState(parallelism == ExecutionConfig.PARALLELISM_DEFAULT || maxParallelism >= parallelism); } ``` ########## flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/StreamNodeForwardGroup.java: ########## @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph.forwardgroup; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamNode; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Stream node level implement for {@link ForwardGroup}. 
*/ +public class StreamNodeForwardGroup implements ForwardGroup { + + private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; + + private final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode = + new HashMap<>(); + + // For a group of chained stream nodes, their parallelism is consistent. In order to make + // calculation and usage easier, we only use the start node to calculate forward group. + public StreamNodeForwardGroup( + final Map<StreamNode, List<StreamNode>> chainedStreamNodeGroupsByStartNode) { + checkNotNull(chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + chainedStreamNodeGroupsByStartNode.keySet().stream() + .map(StreamNode::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a start node in the forward group whose maximum parallelism is smaller than the group's parallelism"); + } + + this.chainedStreamNodeGroupsByStartNode.putAll(chainedStreamNodeGroupsByStartNode); + } + + @Override + public void setParallelism(int parallelism) { + checkState(this.parallelism == ExecutionConfig.PARALLELISM_DEFAULT); + this.parallelism = parallelism; + } + + @Override + public boolean isParallelismDecided() { + return parallelism > 0; + } + + @Override + public int getParallelism() { + checkState(isParallelismDecided()); + return 
parallelism; + } + + @Override + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + @Override + public int getMaxParallelism() { + checkState(isMaxParallelismDecided()); + return maxParallelism; + } + + @VisibleForTesting + public int size() { + return chainedStreamNodeGroupsByStartNode.values().stream().mapToInt(List::size).sum(); + } + + public Iterable<StreamNode> getStartNodes() { + return chainedStreamNodeGroupsByStartNode.keySet(); + } + + public Iterable<List<StreamNode>> getChainedStreamNodeGroups() { + return chainedStreamNodeGroupsByStartNode.values(); + } + + /** + * Responds to merge targetForwardGroup into this and update the parallelism information for + * stream nodes in merged forward group. + * + * @param targetForwardGroup The forward group to be merged. + * @return whether the merge was successful. + */ + public boolean mergeForwardGroup(StreamNodeForwardGroup targetForwardGroup) { + checkNotNull(targetForwardGroup); + + if (targetForwardGroup == this) { + return true; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( + this, targetForwardGroup)) { + return false; + } + + this.chainedStreamNodeGroupsByStartNode.putAll( + targetForwardGroup.chainedStreamNodeGroupsByStartNode); + + Set<Integer> configuredParallelisms = new HashSet<>(); + if (this.isParallelismDecided()) { + configuredParallelisms.add(this.getParallelism()); + } + if (targetForwardGroup.isParallelismDecided()) { + configuredParallelisms.add(targetForwardGroup.getParallelism()); + } + + checkState(configuredParallelisms.size() <= 1); + + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = new HashSet<>(); + if (this.isMaxParallelismDecided()) { + configuredMaxParallelisms.add(this.getMaxParallelism()); + } + if (targetForwardGroup.isMaxParallelismDecided()) { + configuredMaxParallelisms.add(targetForwardGroup.getMaxParallelism()); 
+ } + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism); + } + + if (this.isParallelismDecided() || this.isMaxParallelismDecided()) { Review Comment: It's better to introduce two methods `updateNodeParallelism` and `updateNodeMaxParallelism` and invoke them only when any value really changes. Besides that, looks to me here the invocation of `setMaxParallelism(int)` will always set the `parallelismConfigured` of a `StreamNode` to true. Not sure if it is expected? cc @JunRuiLee ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; Review Comment: checkNotNull. 
########## flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java: ########## @@ -171,6 +276,38 @@ void testOneInputSplitsIntoTwo() throws Exception { checkGroupSize(groups, 1, 3); } + @Test + void testOneInputSplitsIntoTwoForStreamNodeForwardGroup() throws Exception { + Review Comment: Unnecessary empty line. ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + 
this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { Review Comment: maybe "validateStreamEdgeUpdateRequests" ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. */ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + 
this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", Review Comment: Maybe "Skip modifying edge {} because the subscribing output is reused." 
"is not allowed" is more likely to refer to unsupported user actions. ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. 
*/ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. 
+ for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Modification for edge {} is not allowed as the target node with id {} is in frozen list.", + targetEdge, + targetNodeId); + return false; + } + + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null + && targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) { + LOG.info( + "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.", + targetEdge); + return false; + } + + return true; + } + + private void modifyOutputPartitioner( + StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) { + if (newPartitioner == null || targetEdge == null) { + return; + } + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + + StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner(); + + targetEdge.setPartitioner(newPartitioner); + + // For non-chainable edges, we change the ForwardPartitioner to RescalePartitioner to avoid + // limiting the parallelism of the downstream node by the forward edge. + // 1. If the upstream job vertex is created. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && frozenNodeToStartNodeMap.containsKey(sourceNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 2. If the source and target are non-chainable. 
+ if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 3. If the forward group cannot be merged. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !mergeForwardGroups(sourceNodeId, targetNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + output.setPartitioner(targetEdge.getPartitioner()); + } + LOG.info( + "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.", + targetEdge, + oldPartitioner, + newPartitioner, + targetEdge.getPartitioner()); + } + + private boolean mergeForwardGroups(Integer sourceNodeId, Integer targetNodeId) { + StreamNodeForwardGroup sourceForwardGroup = + startAndEndNodeIdToForwardGroupMap.get(sourceNodeId); + StreamNodeForwardGroup targetForwardGroup = + startAndEndNodeIdToForwardGroupMap.get(targetNodeId); + if (sourceForwardGroup == null || targetForwardGroup == null) { + return false; + } + + if (!ForwardGroupComputeUtil.canTargetMergeIntoSourceForwardGroup( Review Comment: It looks to me like it invokes `canTargetMergeIntoSourceForwardGroup(...)` twice: once here and once in `mergeForwardGroup(...)`. ########## flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java: ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation for {@link StreamGraphContext}. 
*/ +@Internal +public class DefaultStreamGraphContext implements StreamGraphContext { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultStreamGraphContext.class); + + private final StreamGraph streamGraph; + private final ImmutableStreamGraph immutableStreamGraph; + private final Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap; + private final Map<Integer, Integer> frozenNodeToStartNodeMap; + private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches; + + public DefaultStreamGraphContext( + StreamGraph streamGraph, + Map<Integer, StreamNodeForwardGroup> startAndEndNodeIdToForwardGroupMap, + Map<Integer, Integer> frozenNodeToStartNodeMap, + Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputsCaches) { + this.streamGraph = streamGraph; + this.immutableStreamGraph = new ImmutableStreamGraph(streamGraph); + this.startAndEndNodeIdToForwardGroupMap = startAndEndNodeIdToForwardGroupMap; + this.frozenNodeToStartNodeMap = frozenNodeToStartNodeMap; + this.opIntermediateOutputsCaches = opIntermediateOutputsCaches; + } + + @Override + public ImmutableStreamGraph getStreamGraph() { + return immutableStreamGraph; + } + + @Override + public boolean modifyStreamEdge(List<StreamEdgeUpdateRequestInfo> requestInfos) { + // We first verify the legality of all requestInfos to ensure that all requests can be + // modified atomically. 
+ for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + if (!modifyStreamEdgeValidate(requestInfo)) { + return false; + } + } + + for (StreamEdgeUpdateRequestInfo requestInfo : requestInfos) { + StreamEdge targetEdge = + getStreamEdge( + requestInfo.getSourceId(), + requestInfo.getTargetId(), + requestInfo.getEdgeId()); + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null) { + modifyOutputPartitioner(targetEdge, newPartitioner); + } + } + + return true; + } + + private boolean modifyStreamEdgeValidate(StreamEdgeUpdateRequestInfo requestInfo) { + Integer sourceNodeId = requestInfo.getSourceId(); + Integer targetNodeId = requestInfo.getTargetId(); + + StreamEdge targetEdge = getStreamEdge(sourceNodeId, targetNodeId, requestInfo.getEdgeId()); + + if (targetEdge == null) { + return false; + } + + // Modification is not allowed when the subscribing output is reused. + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = + opIntermediateOutputsCaches.get(sourceNodeId); + NonChainedOutput output = + opIntermediateOutputs != null ? 
opIntermediateOutputs.get(targetEdge) : null; + if (output != null) { + Set<StreamEdge> consumerStreamEdges = + opIntermediateOutputs.entrySet().stream() + .filter(entry -> entry.getValue().equals(output)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + if (consumerStreamEdges.size() != 1) { + LOG.info( + "Modification for edge {} is not allowed as the subscribing output is reused.", + targetEdge); + return false; + } + } + + if (frozenNodeToStartNodeMap.containsKey(targetNodeId)) { + LOG.info( + "Modification for edge {} is not allowed as the target node with id {} is in frozen list.", + targetEdge, + targetNodeId); + return false; + } + + StreamPartitioner<?> newPartitioner = requestInfo.getOutputPartitioner(); + if (newPartitioner != null + && targetEdge.getPartitioner().getClass().equals(ForwardPartitioner.class)) { + LOG.info( + "Modification for edge {} is not allowed as the origin partitioner is ForwardPartitioner.", + targetEdge); + return false; + } + + return true; + } + + private void modifyOutputPartitioner( + StreamEdge targetEdge, StreamPartitioner<?> newPartitioner) { + if (newPartitioner == null || targetEdge == null) { + return; + } + Integer sourceNodeId = targetEdge.getSourceId(); + Integer targetNodeId = targetEdge.getTargetId(); + + StreamPartitioner<?> oldPartitioner = targetEdge.getPartitioner(); + + targetEdge.setPartitioner(newPartitioner); + + // For non-chainable edges, we change the ForwardPartitioner to RescalePartitioner to avoid + // limiting the parallelism of the downstream node by the forward edge. + // 1. If the upstream job vertex is created. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && frozenNodeToStartNodeMap.containsKey(sourceNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 2. If the source and target are non-chainable. 
+ if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !StreamingJobGraphGenerator.isChainable(targetEdge, streamGraph)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + // 3. If the forward group cannot be merged. + if (targetEdge.getPartitioner() instanceof ForwardPartitioner + && !mergeForwardGroups(sourceNodeId, targetNodeId)) { + targetEdge.setPartitioner(new RescalePartitioner<>()); + } + + Map<StreamEdge, NonChainedOutput> opIntermediateOutputs = Review Comment: It's better to add some comments for this block. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org