rkhachatryan commented on a change in pull request #16773:
URL: https://github.com/apache/flink/pull/16773#discussion_r687818991

File path: 
@@ -0,0 +1,448 @@
+package org.apache.flink.test.savepoint;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.CheckpointListener;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
+import org.apache.flink.runtime.jobgraph.JobEdge;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.stream.Collectors;
+import static java.lang.String.format;
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyMap;
+import static 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+ * A test suite to check that the operator methods are called according to 
contract when the job is
+ * stopped with savepoint. The contract was refined in FLIP-147.
+ *
+ * <p>Checked assumptions:
+ *
+ * <ol>
+ *   <li>Downstream should only be "finished" after all of its the upstreams 
+ *   <li>Order of events when finishing an operator:
+ *       <ol>
+ *         <li>(last data element)
+ *         <li>{@link Watermark#MAX_WATERMARK MAX_WATERMARK} (if with drain)
+ *         <li>{@link BoundedMultiInput#endInput endInput} (if with drain)
+ *         <li>timer service quiesced
+ *         <li>{@link StreamOperator#finish() finish} (if with drain; support 
is planned for
+ *             no-drain)
+ *         <li>{@link 
AbstractStreamOperator#snapshotState(StateSnapshotContext) snapshotState} (for
+ *             the respective checkpoint)
+ *         <li>{@link CheckpointListener#notifyCheckpointComplete 
notifyCheckpointComplete} (for the
+ *             respective checkpoint)
+ *         <li>(task termination)
+ *       </ol>
+ *   <li>Timers can be registered until the operator is finished (though may 
not fire) (simply
+ *       register every 1ms and don't expect any exception)
+ *   <li>The same watermark is received
+ * </ol>
+ *
+ * <p>Variants:
+ *
+ * <ul>
+ *   <li>command - with or without drain (MAX_WATERMARK and endInput should be 
iff drain)
+ *   <li>graph - different exchanges (keyBy, forward)
+ *   <li>graph - multi-inputs (NOT IMPLEMENTED), unions
+ *   <li>graph - FLIP-27 and regular sources (should work for both) - NOT 
+ * </ul>
+ *
+ * <p>Not checked:
+ *
+ * <ul>
+ *   <li>state distribution on recovery (when a new job started from the taken 
savepoint) (a
+ *       separate IT case for partial finishing and state distribution)
+ *   <li>re-taking a savepoint after one fails (and job fails over) (as it 
should not affect
+ *       savepoints)
+ *   <li>taking a savepoint after recovery (as it should not affect savepoints)
+ *   <li>taking a savepoint on a partially completed graph (a separate IT case)
+ * </ul>
+ */
+public class StopWithSavepointITCase extends AbstractTestBase {
+    private static final Logger LOG = 
+    @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
+    @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+    @Parameter public boolean withDrain;
+    @Test
+    public void test() throws Exception {
+        StreamExecutionEnvironment env = 
+        env.setParallelism(4);
+        env.setRestartStrategy(noRestart());
+        env.enableCheckpointing(200); // shouldn't matter
+        env.getConfig().setAutoWatermarkInterval(50);
+        SharedReference<List<TestEvent>> eventsRef =
+                sharedObjects.add(new CopyOnWriteArrayList<>());
+        TestSetup testSetup = buildGraph(env, eventsRef);
+        submitAndStopWithSavepoint(testSetup.jobGraph, withDrain, eventsRef);
+        List<TestEvent> events = eventsRef.get();
+        checkOperatorsLifecycle(events, testSetup);
+        if (withDrain) {
+            // currently (1.14), sources do not stop before taking a savepoint 
and continue emission
+            // todo: enable after updating production code
+            checkDataFlow(events, testSetup);
+        }
+    }
+    private static class TestSetup {
+        private final JobGraph jobGraph;
+        private final Set<String> operatorsWithLifecycleTracking;
+        private final Set<String> operatorsWithDataFlowTracking;
+        private TestSetup(
+                JobGraph jobGraph,
+                Set<String> operatorsWithLifecycleTracking,
+                Set<String> operatorsWithDataFlowTracking) {
+            this.jobGraph = jobGraph;
+            this.operatorsWithLifecycleTracking = 
+            this.operatorsWithDataFlowTracking = operatorsWithDataFlowTracking;
+        }
+    }
+    private TestSetup buildGraph(
+            StreamExecutionEnvironment env, SharedReference<List<TestEvent>> 
eventsRef) {
+        // using hashes so that operators emit identifiable events
+        String srcLeft = OP_ID_HASH_PREFIX + "1";
+        String srcRight = OP_ID_HASH_PREFIX + "2";
+        String mapForward = OP_ID_HASH_PREFIX + "3";
+        String mapKeyed = OP_ID_HASH_PREFIX + "4";
+        // todo: add multi-inputs and FLIP-27 sources
+        DataStream<TestEvent> unitedSources =
+                env.addSource(new TestEventSource(srcLeft, eventsRef))
+                        .setUidHash(srcLeft)
+                        .assignTimestampsAndWatermarks(createWmAssigner())
+                        .union(
+                                env.addSource(new TestEventSource(srcRight, 
+                                        .setUidHash(srcRight)
+        SingleOutputStreamOperator<TestEvent> forwardTransform =
+                unitedSources
+                        .transform(
+                                "transform-1-forward",
+                                TypeInformation.of(TestEvent.class),
+                                new 
TestOneInputStreamOperatorFactory(mapForward, eventsRef))
+                        .setUidHash(mapForward);
+        SingleOutputStreamOperator<TestEvent> keyedTransform =
+                forwardTransform
+                        .startNewChain()
+                        .keyBy(e -> e)
+                        .transform(
+                                "transform-2-keyed",
+                                TypeInformation.of(TestEvent.class),
+                                new 
TestOneInputStreamOperatorFactory(mapKeyed, eventsRef))
+                        .setUidHash(mapKeyed);
+        keyedTransform.addSink(new DiscardingSink<>());
+        return new TestSetup(
+                env.getStreamGraph().getJobGraph(),
+                new HashSet<>(asList(mapForward, mapKeyed)),
+                new HashSet<>(asList(srcLeft, srcRight, mapForward, 
+    }
+    private void submitAndStopWithSavepoint(
+            JobGraph jobGraph, boolean withDrain, 
SharedReference<List<TestEvent>> eventsRef)
+            throws Exception {
+        ClusterClient<?> client = miniClusterResource.getClusterClient();
+        JobID job = client.submitJob(jobGraph).get();
+        while (eventsRef.get().stream().noneMatch(e -> e instanceof 
WatermarkReceivedEvent)) {
+            Thread.sleep(100);
+        }
+        client.stopWithSavepoint(job, withDrain, 
+    }
+    // 
+    // validation
+    private void checkOperatorsLifecycle(List<TestEvent> events, TestSetup 
testSetup) {
+        long lastCheckpointID =
+                events.stream()
+                        .filter(e -> e instanceof CheckpointCompletedEvent)
+                        .mapToLong(e -> ((CheckpointCompletedEvent) 
+                        .max()
+                        .getAsLong();
+        long lastWatermarkTs = withDrain ? 
Watermark.MAX_WATERMARK.getTimestamp() : -1;
+        // not checking if the watermark was the same if !withDrain
+        //                        : events.stream()
+        //                                .filter(e -> e instanceof 
+        //                                .mapToLong(e -> 
((WatermarkReceivedEvent) e).ts)
+        //                                .max()
+        //                                .getAsLong();
+        Map<Tuple2<String, Integer>, List<TestEvent>> eventsByOperator = new 
+        for (TestEvent ev : events) {
+            eventsByOperator
+                    .computeIfAbsent(
+                            Tuple2.of(ev.operatorId, ev.subtaskIndex), ign -> 
new ArrayList<>())
+                    .add(ev);
+        }
+        eventsByOperator.forEach(
+                (operatorIdAndIndex, operatorEvents) -> {
+                    String id = operatorIdAndIndex.f0;
+                    if (testSetup.operatorsWithLifecycleTracking.contains(id)) 
+                        assertEquals(
+                                format(
+                                        "Illegal event sequence for %s[%d] 
(lastCheckpointID: %d, lastWatermarkTs: %d)\n",
+                                        id,
+                                        operatorIdAndIndex.f1,
+                                        lastCheckpointID,
+                                        lastWatermarkTs),
+                                buildExpectedEvents(
+                                        id,
+                                        operatorIdAndIndex.f1,
+                                        lastCheckpointID,
+                                        lastWatermarkTs,
+                                        withDrain),
+                                filterValidationEvents(
+                                        operatorEvents, lastCheckpointID, 
+                    }
+                });
+    }
+    private static List<TestEvent> buildExpectedEvents(
+            String operator,
+            int subtask,
+            long lastCheckpointID,
+            long watermarkTs,
+            boolean withDrain) {
+        List<TestEvent> expected = new ArrayList<>();
+        if (withDrain) {
+            // without drain, watermark can be emitted even after stopping 
some operators
+            // so they won't have the same watermark; so only check if 
+            expected.add(new WatermarkReceivedEvent(operator, subtask, 
+            expected.add(new InputEndedEvent(operator, subtask));
+            // currently (1.14), finish is only called withDrain
+            // todo: enable after updating production code
+            expected.add(
+                    new OperatorFinishedEvent(
+                            operator,
+                            subtask,
+                            0,
+                            new 
+        }
+        expected.add(new CheckpointStartedEvent(operator, subtask, 
+        expected.add(new CheckpointCompletedEvent(operator, subtask, 
+        return expected;
+    }
+    private List<TestEvent> filterValidationEvents(
+            List<TestEvent> events, long allowedCheckpointID, long 
allowedWatermarkTS) {
+        return events.stream()
+                .filter(ev -> isValidationEvent(ev, allowedCheckpointID, 
+                .collect(Collectors.toList());
+    }
+    private boolean isValidationEvent(
+            TestEvent ev, long lastCheckpointID, long allowedWatermarkTS) {
+        if (ev instanceof CheckpointCompletedEvent) {
+            return ((CheckpointCompletedEvent) ev).checkpointID == 
+        } else if (ev instanceof CheckpointStartedEvent) {
+            return ((CheckpointStartedEvent) ev).checkpointID == 
+        } else if (ev instanceof WatermarkReceivedEvent) {
+            return ((WatermarkReceivedEvent) ev).ts == allowedWatermarkTS;
+        } else {
+            return !(ev instanceof DataSentEvent);
+        }
+    }
+    /** Check that all data from the upstream reached the respective 
downstreams. */
+    private void checkDataFlow(List<TestEvent> events, TestSetup testSetup) {
+        Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents = new 
+        for (TestEvent ev : events) {
+            if (ev instanceof OperatorFinishedEvent) {
+                finishEvents
+                        .computeIfAbsent(ev.operatorId, ign -> new HashMap<>())
+                        .put(ev.subtaskIndex, ((OperatorFinishedEvent) ev));
+            }
+        }
+        for (JobVertex upstream : testSetup.jobGraph.getVertices()) {
+            for (IntermediateDataSet produced : 
upstream.getProducedDataSets()) {
+                for (JobEdge edge : produced.getConsumers()) {
+                    Optional<String> upstreamID = 
getTrackedOperatorID(upstream, true, testSetup);
+                    Optional<String> downstreamID =
+                            getTrackedOperatorID(edge.getTarget(), false, 
+                    if (upstreamID.isPresent() && downstreamID.isPresent()) {
+                        checkDataFlow(upstreamID.get(), downstreamID.get(), 
edge, finishEvents);
+                    } else {
+                        LOG.debug("Ignoring edge (untracked operator): {}", 
+                    }
+                }
+            }
+        }
+    }
+    private void checkDataFlow(
+            String upstreamID,
+            String downstreamID,
+            JobEdge edge,
+            Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents) {
+        LOG.debug(
+                "Checking {} edge\n  from {} ({})\n  to {} ({})",
+                edge.getDistributionPattern(),
+                edge.getSource().getProducer().getName(),
+                upstreamID,
+                edge.getTarget().getName(),
+                downstreamID);
+        Map<Integer, OperatorFinishedEvent> downstreamFinishInfo =
+                getForOperator(downstreamID, finishEvents);
+        Map<Integer, OperatorFinishedEvent> upstreamFinishInfo =
+                getForOperator(upstreamID, finishEvents);
+        upstreamFinishInfo.forEach(
+                (upstreamIndex, upstreamInfo) ->
+                        assertTrue(
+                                anySubtaskReceived(
+                                        upstreamID,
+                                        upstreamIndex,
+                                        upstreamInfo.lastSent,
+                                        downstreamFinishInfo.values())));

Review comment:
       This is a check that for each upstream subtask there exists a downstream 
subtask that received its latest emitted element.
   For both keyed and forward exchanges there should be exactly one such subtask.
   Previously, I mistakenly expected that after a keyed exchange **all** 
downstream subtasks would receive the same data; however, that is actually 
the `broadcast` case. Checking all subtasks in the broadcast case doesn't seem 
like the highest priority, WDYT?

