    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + * <p>
    + * http://www.apache.org/licenses/LICENSE-2.0
    + * <p>
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.flink.streaming.tests;
    +import org.apache.flink.api.common.functions.MapFunction;
    +import org.apache.flink.api.common.functions.RichFlatMapFunction;
    +import org.apache.flink.api.common.functions.RuntimeContext;
    +import org.apache.flink.api.common.restartstrategy.RestartStrategies;
    +import org.apache.flink.api.common.state.ListState;
    +import org.apache.flink.api.common.state.ListStateDescriptor;
    +import org.apache.flink.api.common.state.ValueState;
    +import org.apache.flink.api.common.state.ValueStateDescriptor;
    +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
    +import org.apache.flink.runtime.state.CheckpointListener;
    +import org.apache.flink.runtime.state.FunctionInitializationContext;
    +import org.apache.flink.runtime.state.FunctionSnapshotContext;
    +import org.apache.flink.runtime.state.filesystem.FsStateBackend;
    +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
    +import org.apache.flink.streaming.api.environment.CheckpointConfig;
    +import org.apache.flink.streaming.api.functions.sink.PrintSinkFunction;
    +import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
    +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
    +import org.apache.flink.util.Collector;
    +import org.apache.flink.util.Preconditions;
    +import org.apache.commons.lang3.RandomStringUtils;
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +import java.util.ArrayList;
    +import java.util.HashSet;
    +import java.util.Iterator;
    +import java.util.List;
    +import java.util.Set;
    +/**
    + * Automatic end-to-end test for local recovery (including sticky allocation).
    + */
    +public class StickyAllocationAndLocalRecoveryTestJob {
    +   private static final Logger LOG = 
    +   public static void main(String[] args) throws Exception { // entry point: parses the arguments, configures the environment, then builds and runs the test pipeline
    +           final ParameterTool pt = ParameterTool.fromArgs(args);
    +           final StreamExecutionEnvironment env = // NOTE(review): the right-hand side is missing from this excerpt
    +           env.setParallelism(pt.getInt("parallelism", 1));
    +           env.setMaxParallelism(pt.getInt("maxParallelism", 
pt.getInt("parallelism", 1))); // "maxParallelism" falls back to the "parallelism" setting (itself defaulting to 1)
    +           env.enableCheckpointing(pt.getInt("checkpointInterval", 1000)); // interval defaults to 1000 (milliseconds, per Flink's enableCheckpointing(long) contract)
pt.getInt("restartDelay", 0))); // NOTE(review): the start of this statement is missing from this excerpt; presumably the restart strategy setup - confirm
    +           if (pt.has("externalizedCheckpoints") && 
pt.getBoolean("externalizedCheckpoints", false)) {
    +           } // NOTE(review): the body of this branch is empty in this excerpt
    +           String stateBackend = pt.get("stateBackend", "file"); // accepted values below: "file" (default) or "rocks"
    +           String checkpointDir = pt.getRequired("checkpointDir"); // mandatory argument
    +           boolean killJvmOnFail = pt.getBoolean("killJvmOnFail", false);
    +           if ("file".equals(stateBackend)) {
    +                   boolean asyncCheckpoints = 
pt.getBoolean("asyncCheckpoints", false);
    +                   env.setStateBackend(new FsStateBackend(checkpointDir, // NOTE(review): the remaining arguments are missing from this excerpt
    +           } else if ("rocks".equals(stateBackend)) {
    +                   boolean incrementalCheckpoints = 
pt.getBoolean("incrementalCheckpoints", false);
    +                   env.setStateBackend(new 
RocksDBStateBackend(checkpointDir, incrementalCheckpoints));
    +           } else {
    +                   throw new IllegalArgumentException("Unknown backend: " 
+ stateBackend); // any value other than "file" or "rocks" is rejected
    +           }
    +           // make parameters available in the web interface
    +           env.getConfig().setGlobalJobParameters(pt);
    +           // delay to throttle down the production of the source (the source sleeps this many milliseconds between events; values <= 0 disable the sleep)
    +           long delay = pt.has("delay") ? pt.getLong("delay") : 0L;
    +           // the maximum number of attempts, before the job finishes with success (the source emits one final event and returns once the attempt number exceeds this value)
    +           int maxAttempts = pt.has("maxAttempts") ? 
pt.getInt("maxAttempts") : 3;
    +           // size of one artificial value
    +           int valueSize = pt.has("valueSize") ? pt.getInt("valueSize") : // NOTE(review): the default value is missing from this excerpt
    +           env.addSource(new RandomLongSource(maxAttempts, delay))
    +                   .keyBy((KeySelector<Long, Long>) aLong -> aLong) // each record is keyed by its own value
    +                   .flatMap(new StateCreatingFlatMap(valueSize, // NOTE(review): second argument is missing from this excerpt; the constructor takes (int valueSize, boolean killTaskOnFailure)
    +                   .map((MapFunction<String, String>) value -> value) // identity map
    +                   .addSink(new PrintSinkFunction<>());
    +           env.execute("Sticky Allocation And Local Recovery Test");
    +   }
    +   /**
    +    * Source function that produces a long sequence.
    +    *
    +    * <p>While running, it repeatedly emits {@code currentKey} and then advances it by
    +    * {@code numberOfParallelSubtasks}, sleeping {@code delay} milliseconds between events
    +    * when {@code delay} is positive. If the attempt number exceeds {@code maxAttempts}, it
    +    * emits one final event and returns instead. The current key is written to the list state
    +    * registered under the name "currentKey" on every snapshot.
    +    */
    +   private static final class RandomLongSource extends 
RichSourceFunction<Long> implements CheckpointedFunction {
    +           private static final long serialVersionUID = 1L;
    +           /**
    +            * Generator delay between two events.
    +            * Passed to {@code Thread.sleep}, i.e. milliseconds; only applied when greater than zero.
    +            */
    +           final long delay;
    +           /**
    +            * Maximum restarts before shutting down this source.
    +            * Compared against the runtime context's attempt number at the start of {@code run}.
    +            */
    +           final int maxAttempts;
    +           /**
    +            * State that holds the current key for recovery.
    +            * Cleared and refilled with a single element on every snapshot.
    +            */
    +           ListState<Long> sourceCurrentKeyState;
    +           /**
    +            * Generator's current key.
    +            */
    +           long currentKey;
    +           /**
    +            * Generator runs while this is true.
    +            * Set to true in the constructor and to false by {@code cancel()}.
    +            */
    +           volatile boolean running;
    +           RandomLongSource(int maxAttempts, long delay) {
    +                   this.delay = delay;
    +                   this.maxAttempts = maxAttempts;
    +                   this.running = true;
    +           }
    +           @Override
    +           public void run(SourceContext<Long> sourceContext) throws 
Exception {
    +                   int numberOfParallelSubtasks = // NOTE(review): the right-hand side is missing from this excerpt
    +                   int subtaskIdx = // NOTE(review): the right-hand side is missing from this excerpt
    +                   // the source emits one final event and shuts down once 
we have reached max attempts.
    +                   if (getRuntimeContext().getAttemptNumber() > 
maxAttempts) {
    +                           sourceContext.collect(Long.MAX_VALUE - // NOTE(review): the expression is truncated in this excerpt
    +                           return;
    +                   }
    +                   while (running) { // loops until cancel() clears the flag
    +                           sourceContext.collect(currentKey);
    +                           currentKey += numberOfParallelSubtasks; // advance by numberOfParallelSubtasks after every emitted key
    +                           if (delay > 0) {
    +                                   Thread.sleep(delay); // throttle: pause for delay milliseconds between events
    +                           }
    +                   }
    +           }
    +           @Override
    +           public void cancel() {
    +                   running = false; // makes the while loop in run() exit
    +           }
    +           @Override
    +           public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
    +                   sourceCurrentKeyState.clear(); // replace the previous content so the list holds only the current key
    +                   sourceCurrentKeyState.add(currentKey);
    +           }
    +           @Override
    +           public void initializeState(FunctionInitializationContext 
context) throws Exception {
    +                   ListStateDescriptor<Long> currentKeyDescriptor = new 
ListStateDescriptor<>("currentKey", Long.class);
    +                   sourceCurrentKeyState = // NOTE(review): the right-hand side is missing from this excerpt
    +                   currentKey = // NOTE(review): the right-hand side (initial key) is missing from this excerpt
    +                   Iterable<Long> iterable = sourceCurrentKeyState.get();
    +                   if (iterable != null) {
    +                           Iterator<Long> iterator = iterable.iterator();
    +                           if (iterator.hasNext()) { // state contains an element: overwrite the initial key
    +                                   currentKey =; // NOTE(review): the assigned expression is missing from this excerpt; presumably the first element of the state - confirm
    +                           }
    +                   }
    +           }
    +   }
    +   /**
    +    * Stateful map function. Failure creation and checks happen here.
    +    *
    +    * <p>{@code initializeState} inspects the restored scheduling info and, if the allocation
    +    * check fails, prepares an allocation failure message; it also decides whether this subtask
    +    * is the failing one for the current attempt. {@code notifyCheckpointComplete} arms the
    +    * artificial failure. {@code flatMap} reports a pending failure message downstream, triggers
    +    * the artificial failure (JVM halt or exception) when armed, and otherwise writes one
    +    * artificial value into keyed state for the current key.
    +    */
    +   private static final class StateCreatingFlatMap
    +           extends RichFlatMapFunction<Long, String> implements 
CheckpointedFunction, CheckpointListener {
    +           private static final long serialVersionUID = 1L;
    +           /**
    +            * User configured size of the generated artificial values in 
the keyed state.
    +            * Used as the character count of the random string written in {@code flatMap}.
    +            */
    +           final int valueSize;
    +           /**
    +            * Holds the user configuration if the artificial test failure 
is killing the JVM.
    +            * When false, the failure is a thrown {@code RuntimeException} instead.
    +            */
    +           final boolean killTaskOnFailure;
    +           /**
    +            * This state is used to create artificial keyed state in the 
    +            * state backend (sentence was cut off in this excerpt; wording reconstructed).
    +            * {@code flatMap} expects it to be empty for every incoming key and then fills it.
    +            */
    +           transient ValueState<String> valueState;
    +           /**
    +            * This state is used to persist the schedulingAndFailureInfo 
to state.
    +            */
    +           transient ListState<MapperSchedulingAndFailureInfo> // NOTE(review): the field name is missing from this excerpt; presumably schedulingAndFailureState, as used in initializeState
    +           /**
    +            * This contains the current scheduling and failure meta data.
    +            */
    +           transient MapperSchedulingAndFailureInfo // NOTE(review): the field name is missing from this excerpt; presumably currentSchedulingAndFailureInfo, as used below
    +           /**
    +            * Message to indicate that recovery detected a failure with 
sticky allocation.
    +            * Null when there is nothing to report; reset to null after being emitted once.
    +            */
    +           transient volatile String allocationFailureMessage;
    +           /**
    +            * If this flag is true, the next invocation of the map 
function introduces a test failure.
    +            * Set in {@code notifyCheckpointComplete}.
    +            */
    +           transient volatile boolean failTask;
    +           StateCreatingFlatMap(int valueSize, boolean killTaskOnFailure) {
    +                   this.valueSize = valueSize;
    +                   this.failTask = false;
    +                   this.killTaskOnFailure = killTaskOnFailure;
    +                   this.allocationFailureMessage = null;
    +           }
    +           @Override
    +           public void flatMap(Long key, Collector<String> collector) 
throws IOException {
    +                   if (allocationFailureMessage != null) {
    +                           // Report the failure downstream, so that we 
can get the message from the output.
    +                           collector.collect(allocationFailureMessage);
    +                           allocationFailureMessage = null; // emit the message only once
    +                   }
    +                   if (failTask) {
    +                           // we fail the task, either by killing the JVM 
hard, or by throwing a user code exception.
    +                           if (killTaskOnFailure) {
    +                                   Runtime.getRuntime().halt(-1); // forcibly terminates the JVM with exit status -1
    +                           } else {
    +                                   throw new RuntimeException("Artificial 
user code exception.");
    +                           }
    +                   }
    +                   // sanity check: no value may exist yet for the current key
    +                   if (null != valueState.value()) {
    +                           throw new IllegalStateException("This should 
never happen, keys are generated monotonously.");
    +                   }
    +                   // store artificial data to blow up the state (random string of valueSize letters and digits)
    +                   valueState.update(RandomStringUtils.random(valueSize, 
true, true));
    +           }
    +           @Override
    +           public void snapshotState(FunctionSnapshotContext 
functionSnapshotContext) {
    +           } // NOTE(review): the body is empty in this excerpt
    +           @Override
    +           public void initializeState(FunctionInitializationContext 
functionInitializationContext) throws Exception {
    +                   ValueStateDescriptor<String> stateDescriptor =
    +                           new ValueStateDescriptor<>("state", // NOTE(review): the remaining arguments are missing from this excerpt
    +                   valueState = // NOTE(review): the right-hand side is missing from this excerpt
    +                   ListStateDescriptor<MapperSchedulingAndFailureInfo> 
mapperInfoStateDescriptor =
    +                           new ListStateDescriptor<>("mapperState", // NOTE(review): the remaining arguments are missing from this excerpt
    +                   schedulingAndFailureState = // NOTE(review): the right-hand side is missing from this excerpt
    +                   StreamingRuntimeContext runtimeContext = 
(StreamingRuntimeContext) getRuntimeContext();
    +                   String allocationID = runtimeContext.getAllocationID();
    +                   final int thisJvmPid = getJvmPid();
    +                   final Set<Integer> killedJvmPids = new HashSet<>();
    +                   // here we check if the sticky scheduling worked as expected
    +                   if (functionInitializationContext.isRestored()) {
    +                           Iterable<MapperSchedulingAndFailureInfo> 
iterable = schedulingAndFailureState.get();
    +                           String taskNameWithSubtasks = // NOTE(review): the right-hand side is missing from this excerpt
    +                           MapperSchedulingAndFailureInfo infoForThisTask 
= null;
    +                           List<MapperSchedulingAndFailureInfo> 
completeInfo = new ArrayList<>();
    +                           if (iterable != null) {
    +                                   for (MapperSchedulingAndFailureInfo 
testInfo : iterable) {
    +                                           completeInfo.add(testInfo); // keep every restored entry for the failure message
    +                                           if 
(taskNameWithSubtasks.equals(testInfo.taskNameWithSubtask)) {
    +                                                   infoForThisTask = // NOTE(review): the right-hand side is missing from this excerpt; presumably testInfo - confirm
    +                                           }
    +                                           if (testInfo.killedJvm) {
    +                                           } // NOTE(review): the body is empty in this excerpt; presumably it records the PID in killedJvmPids - confirm
    +                                   }
    +                           }
    +                           Preconditions.checkNotNull(infoForThisTask, 
"Expected to find info here.");
    +                           if 
(!isScheduledToCorrectAllocation(infoForThisTask, allocationID, killedJvmPids)) 
    +                                   allocationFailureMessage = // NOTE(review): the opening brace and the formatting call are missing from this excerpt
    +                                           "Sticky allocation test failed: 
Subtask %s in attempt %d was rescheduled from allocation %s " +
    +                                                   "on JVM with PID %d to 
unexpected allocation %s on JVM with PID %d.\n" +
    +                                                   "Complete information 
from before the crash: %s.",
    +                                           infoForThisTask.allocationId,
    +                                           infoForThisTask.jvmPid,
    +                                           allocationID,
    +                                           thisJvmPid,
    +                                           completeInfo);
    +                           }
    +                   }
    +                   // We determine which of the subtasks will produce the 
artificial failure
    +                   boolean failingTask = shouldTaskFailForThisAttempt();
    +                   // We take note of all the meta info that we require to 
check sticky scheduling in the next re-attempt
    +                   this.currentSchedulingAndFailureInfo = new // NOTE(review): the constructor name is missing from this excerpt; the field type is MapperSchedulingAndFailureInfo
    +                           failingTask,
    +                           failingTask && killTaskOnFailure, // true only if this subtask fails by killing the JVM
    +                           thisJvmPid,
    +                           runtimeContext.getTaskNameWithSubtasks(),
    +                           allocationID);
    +                   schedulingAndFailureState.clear(); // discard the restored entries
    +           }
    +           @Override
    +           public void notifyCheckpointComplete(long checkpointId) {
    +                   // we can only fail the task after at least one 
checkpoint is completed to record progress.
    +                   failTask = currentSchedulingAndFailureInfo.failingTask; // arms the failure for the next flatMap call if this is the failing subtask
    +           }
    +           private boolean shouldTaskFailForThisAttempt() { // true when (attempt number modulo numSubtasks) equals this subtask's index
    +                   RuntimeContext runtimeContext = getRuntimeContext();
    +                   int numSubtasks = // NOTE(review): the right-hand side is missing from this excerpt
    +                   int subtaskIdx = runtimeContext.getIndexOfThisSubtask();
    +                   int attempt = runtimeContext.getAttemptNumber();
    +                   return (attempt % numSubtasks) == subtaskIdx;
    +           }
    +           private boolean isScheduledToCorrectAllocation( // used by initializeState to decide whether to report a sticky allocation failure
    +                   MapperSchedulingAndFailureInfo infoForThisTask,
    +                   String allocationID,
    +                   Set<Integer> killedJvmPids) {
    +                   return // NOTE(review): both operands of the || expression are missing from this excerpt
    +                           || 
    +           }
    +   }
    +   private static int getJvmPid() throws Exception {
    +  runtime =
    +          ;
    +           java.lang.reflect.Field jvm = 
    +           jvm.setAccessible(true);
    +  mgmt =
    +                   ( jvm.get(runtime);
    +           java.lang.reflect.Method pidMethod =
    +                   mgmt.getClass().getDeclaredMethod("getProcessId");
    +           pidMethod.setAccessible(true);
