zentol closed pull request #6777: [FLINK-10461] [State Backends, Checkpointing] Speed up download files when restore from DFS
URL: https://github.com/apache/flink/pull/6777
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff is
reproduced below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/docs/_includes/generated/rocks_db_configuration.html 
b/docs/_includes/generated/rocks_db_configuration.html
index 8983f8b41dd..81f6b53f117 100644
--- a/docs/_includes/generated/rocks_db_configuration.html
+++ b/docs/_includes/generated/rocks_db_configuration.html
@@ -7,6 +7,11 @@
         </tr>
     </thead>
     <tbody>
+        <tr>
+            
<td><h5>state.backend.rocksdb.checkpoint.restore.thread.num</h5></td>
+            <td style="word-wrap: break-word;">1</td>
+            <td>The number of threads used to download files from DFS in 
RocksDBStateBackend.</td>
+        </tr>
         <tr>
             <td><h5>state.backend.rocksdb.localdir</h5></td>
             <td style="word-wrap: break-word;">(none)</td>
diff --git 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
index fff8f02285f..96a28903792 100644
--- 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
+++ 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
@@ -82,6 +82,7 @@ public void testListSerialization() throws Exception {
                                new KeyGroupRange(0, 0),
                                new ExecutionConfig(),
                                false,
+                               1,
                                TestLocalRecoveryConfig.disabled(),
                                RocksDBStateBackend.PriorityQueueStateType.HEAP,
                                TtlTimeProvider.DEFAULT,
@@ -126,6 +127,7 @@ public void testMapSerialization() throws Exception {
                                new KeyGroupRange(0, 0),
                                new ExecutionConfig(),
                                false,
+                               1,
                                TestLocalRecoveryConfig.disabled(),
                                RocksDBStateBackend.PriorityQueueStateType.HEAP,
                                TtlTimeProvider.DEFAULT,
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/util/DirectExecutorService.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/DirectExecutorService.java
similarity index 69%
rename from 
flink-runtime/src/test/java/org/apache/flink/runtime/util/DirectExecutorService.java
rename to 
flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/DirectExecutorService.java
index 1d7c971d19c..c37adbdcdd5 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/util/DirectExecutorService.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/DirectExecutorService.java
@@ -16,7 +16,9 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.util;
+package org.apache.flink.runtime.concurrent;
+
+import javax.annotation.Nonnull;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -31,37 +33,42 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
-public class DirectExecutorService implements ExecutorService {
-       private boolean _shutdown = false;
+/** The direct executor service directly executes the runnables and the 
callables in the calling thread. */
+class DirectExecutorService implements ExecutorService {
+       static final DirectExecutorService INSTANCE = new 
DirectExecutorService();
+
+       private boolean isShutdown = false;
 
        @Override
        public void shutdown() {
-               _shutdown = true;
+               isShutdown = true;
        }
 
        @Override
+       @Nonnull
        public List<Runnable> shutdownNow() {
-               _shutdown = true;
+               isShutdown = true;
                return Collections.emptyList();
        }
 
        @Override
        public boolean isShutdown() {
-               return _shutdown;
+               return isShutdown;
        }
 
        @Override
        public boolean isTerminated() {
-               return _shutdown;
+               return isShutdown;
        }
 
        @Override
-       public boolean awaitTermination(long timeout, TimeUnit unit) throws 
InterruptedException {
-               return _shutdown;
+       public boolean awaitTermination(long timeout, @Nonnull TimeUnit unit) {
+               return isShutdown;
        }
 
        @Override
-       public <T> Future<T> submit(Callable<T> task) {
+       @Nonnull
+       public <T> Future<T> submit(@Nonnull Callable<T> task) {
                try {
                        T result = task.call();
 
@@ -72,34 +79,40 @@ public boolean awaitTermination(long timeout, TimeUnit 
unit) throws InterruptedE
        }
 
        @Override
-       public <T> Future<T> submit(Runnable task, T result) {
+       @Nonnull
+       public <T> Future<T> submit(@Nonnull Runnable task, T result) {
                task.run();
 
                return new CompletedFuture<>(result, null);
        }
 
        @Override
-       public Future<?> submit(Runnable task) {
+       @Nonnull
+       public Future<?> submit(@Nonnull Runnable task) {
                task.run();
                return new CompletedFuture<>(null, null);
        }
 
        @Override
-       public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> 
tasks) throws InterruptedException {
+       @Nonnull
+       public <T> List<Future<T>> invokeAll(@Nonnull Collection<? extends 
Callable<T>> tasks) {
                ArrayList<Future<T>> result = new ArrayList<>();
 
                for (Callable<T> task : tasks) {
                        try {
-                               result.add(new CompletedFuture<T>(task.call(), 
null));
+                               result.add(new CompletedFuture<>(task.call(), 
null));
                        } catch (Exception e) {
-                               result.add(new CompletedFuture<T>(null, e));
+                               result.add(new CompletedFuture<>(null, e));
                        }
                }
                return result;
        }
 
        @Override
-       public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> 
tasks, long timeout, TimeUnit unit) throws InterruptedException {
+       @Nonnull
+       public <T> List<Future<T>> invokeAll(
+               @Nonnull Collection<? extends Callable<T>> tasks, long timeout, 
@Nonnull TimeUnit unit) {
+
                long end = System.currentTimeMillis() + unit.toMillis(timeout);
                Iterator<? extends Callable<T>> iterator = tasks.iterator();
                ArrayList<Future<T>> result = new ArrayList<>();
@@ -108,13 +121,13 @@ public boolean awaitTermination(long timeout, TimeUnit 
unit) throws InterruptedE
                        Callable<T> callable = iterator.next();
 
                        try {
-                               result.add(new 
CompletedFuture<T>(callable.call(), null));
+                               result.add(new 
CompletedFuture<>(callable.call(), null));
                        } catch (Exception e) {
-                               result.add(new CompletedFuture<T>(null, e));
+                               result.add(new CompletedFuture<>(null, e));
                        }
                }
 
-               while(iterator.hasNext()) {
+               while (iterator.hasNext()) {
                        iterator.next();
                        result.add(new Future<T>() {
                                @Override
@@ -133,12 +146,12 @@ public boolean isDone() {
                                }
 
                                @Override
-                               public T get() throws InterruptedException, 
ExecutionException {
+                               public T get() {
                                        throw new CancellationException("Task 
has been cancelled.");
                                }
 
                                @Override
-                               public T get(long timeout, TimeUnit unit) 
throws InterruptedException, ExecutionException, TimeoutException {
+                               public T get(long timeout, @Nonnull TimeUnit 
unit) {
                                        throw new CancellationException("Task 
has been cancelled.");
                                }
                        });
@@ -148,7 +161,8 @@ public T get(long timeout, TimeUnit unit) throws 
InterruptedException, Execution
        }
 
        @Override
-       public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws 
InterruptedException, ExecutionException {
+       @Nonnull
+       public <T> T invokeAny(@Nonnull Collection<? extends Callable<T>> 
tasks) throws ExecutionException {
                Exception exception = null;
 
                for (Callable<T> task : tasks) {
@@ -164,7 +178,11 @@ public T get(long timeout, TimeUnit unit) throws 
InterruptedException, Execution
        }
 
        @Override
-       public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long 
timeout, TimeUnit unit) throws InterruptedException, ExecutionException, 
TimeoutException {
+       public <T> T invokeAny(
+               @Nonnull Collection<? extends Callable<T>> tasks,
+               long timeout,
+               @Nonnull TimeUnit unit) throws ExecutionException, 
TimeoutException {
+
                long end = System.currentTimeMillis() + unit.toMillis(timeout);
                Exception exception = null;
 
@@ -189,15 +207,15 @@ public T get(long timeout, TimeUnit unit) throws 
InterruptedException, Execution
        }
 
        @Override
-       public void execute(Runnable command) {
+       public void execute(@Nonnull Runnable command) {
                command.run();
        }
 
-       public static class CompletedFuture<V> implements Future<V> {
+       static class CompletedFuture<V> implements Future<V> {
                private final V value;
                private final Exception exception;
 
-               public CompletedFuture(V value, Exception exception) {
+               CompletedFuture(V value, Exception exception) {
                        this.value = value;
                        this.exception = exception;
                }
@@ -218,7 +236,7 @@ public boolean isDone() {
                }
 
                @Override
-               public V get() throws InterruptedException, ExecutionException {
+               public V get() throws ExecutionException {
                        if (exception != null) {
                                throw new ExecutionException(exception);
                        } else {
@@ -227,7 +245,7 @@ public V get() throws InterruptedException, 
ExecutionException {
                }
 
                @Override
-               public V get(long timeout, TimeUnit unit) throws 
InterruptedException, ExecutionException, TimeoutException {
+               public V get(long timeout, @Nonnull TimeUnit unit) throws 
ExecutionException {
                        return get();
                }
        }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/Executors.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/Executors.java
index 703ac4eba7f..41d9a325341 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/Executors.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/Executors.java
@@ -18,22 +18,16 @@
 
 package org.apache.flink.runtime.concurrent;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.annotation.Nonnull;
-
 import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
 
 import scala.concurrent.ExecutionContext;
 
 /**
- * Collection of {@link Executor} implementations.
+ * Collection of {@link Executor}, {@link ExecutorService} and {@link 
ExecutionContext} implementations.
  */
 public class Executors {
 
-       private static final Logger LOG = 
LoggerFactory.getLogger(Executors.class);
-
        /**
         * Return a direct executor. The direct executor directly executes the 
runnable in the calling
         * thread.
@@ -41,22 +35,19 @@
         * @return Direct executor
         */
        public static Executor directExecutor() {
-               return DirectExecutor.INSTANCE;
+               return DirectExecutorService.INSTANCE;
        }
 
        /**
-        * Direct executor implementation.
+        * Return a new direct executor service.
+        *
+        * <p>The direct executor service directly executes the runnables and 
the callables in the calling
+        * thread.
+        *
+        * @return New direct executor service
         */
-       private static class DirectExecutor implements Executor {
-
-               static final DirectExecutor INSTANCE = new DirectExecutor();
-
-               private DirectExecutor() {}
-
-               @Override
-               public void execute(@Nonnull Runnable command) {
-                       command.run();
-               }
+       public static ExecutorService newDirectExecutorService() {
+               return new DirectExecutorService();
        }
 
        /**
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/heartbeat/HeartbeatManagerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/heartbeat/HeartbeatManagerTest.java
index e4e86bb2803..f8bfa9443cb 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/heartbeat/HeartbeatManagerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/heartbeat/HeartbeatManagerTest.java
@@ -23,7 +23,6 @@
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
-import org.apache.flink.runtime.util.DirectExecutorService;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -79,7 +78,7 @@ public void testRegularHeartbeat() {
                        heartbeatTimeout,
                        ownResourceID,
                        heartbeatListener,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        scheduledExecutor,
                        LOG);
 
@@ -122,7 +121,7 @@ public void testHeartbeatMonitorUpdate() {
                        heartbeatTimeout,
                        ownResourceID,
                        heartbeatListener,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        scheduledExecutor,
                        LOG);
 
@@ -163,7 +162,7 @@ public void testHeartbeatTimeout() throws Exception {
                        heartbeatTimeout,
                        ownResourceID,
                        heartbeatListener,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        new ScheduledExecutorServiceAdapter(new 
ScheduledThreadPoolExecutor(1)),
                        LOG);
 
@@ -215,7 +214,7 @@ public void testHeartbeatCluster() throws Exception {
                        heartbeatTimeout,
                        resourceID,
                        heartbeatListener,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        new ScheduledExecutorServiceAdapter(new 
ScheduledThreadPoolExecutor(1)),
                        LOG);
 
@@ -224,7 +223,7 @@ public void testHeartbeatCluster() throws Exception {
                        heartbeatTimeout,
                        resourceID2,
                        heartbeatListener2,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        new ScheduledExecutorServiceAdapter(new 
ScheduledThreadPoolExecutor(1)),
                        LOG);
 
@@ -264,7 +263,7 @@ public void testTargetUnmonitoring() throws 
InterruptedException, ExecutionExcep
                        heartbeatTimeout,
                        resourceID,
                        heartbeatListener,
-                       new DirectExecutorService(),
+                       Executors.directExecutor(),
                        new ScheduledExecutorServiceAdapter(new 
ScheduledThreadPoolExecutor(1)),
                        LOG);
 
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index a37f8aa8df8..700c5468c8c 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -37,7 +37,6 @@
 import 
org.apache.flink.contrib.streaming.state.snapshot.RocksFullSnapshotStrategy;
 import 
org.apache.flink.contrib.streaming.state.snapshot.RocksIncrementalSnapshotStrategy;
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileStatus;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
@@ -127,6 +126,7 @@
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
+import static 
org.apache.flink.contrib.streaming.state.RocksDbStateDataTransfer.transferAllStateDataToDirectory;
 import static 
org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
 import static 
org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
 import static 
org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
@@ -217,6 +217,9 @@
        /** True if incremental checkpointing is enabled. */
        private final boolean enableIncrementalCheckpointing;
 
+       /** Thread number used to download from DFS when restore. */
+       private final int restoringThreadNum;
+
        /** The configuration of local recovery. */
        private final LocalRecoveryConfig localRecoveryConfig;
 
@@ -251,6 +254,7 @@ public RocksDBKeyedStateBackend(
                KeyGroupRange keyGroupRange,
                ExecutionConfig executionConfig,
                boolean enableIncrementalCheckpointing,
+               int restoringThreadNum,
                LocalRecoveryConfig localRecoveryConfig,
                RocksDBStateBackend.PriorityQueueStateType 
priorityQueueStateType,
                TtlTimeProvider ttlTimeProvider,
@@ -264,6 +268,7 @@ public RocksDBKeyedStateBackend(
                this.operatorIdentifier = 
Preconditions.checkNotNull(operatorIdentifier);
 
                this.enableIncrementalCheckpointing = 
enableIncrementalCheckpointing;
+               this.restoringThreadNum = restoringThreadNum;
                this.rocksDBResourceGuard = new ResourceGuard();
 
                // ensure that we use the right merge operator, because other 
code relies on this
@@ -494,7 +499,7 @@ public void restore(Collection<KeyedStateHandle> 
restoreState) throws Exception
                LOG.info("Initializing RocksDB keyed state backend.");
 
                if (LOG.isDebugEnabled()) {
-                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoreState);
+                       LOG.debug("Restoring snapshot from state handles: {}, 
will use {} thread(s) to download files from DFS.", restoreState, 
restoringThreadNum);
                }
 
                // clear all meta data
@@ -876,7 +881,7 @@ void restoreWithoutRescaling(KeyedStateHandle 
rawStateHandle) throws Exception {
                                        IncrementalKeyedStateHandle 
restoreStateHandle = (IncrementalKeyedStateHandle) rawStateHandle;
 
                                        // read state data.
-                                       
transferAllStateDataToDirectory(restoreStateHandle, 
temporaryRestoreInstancePath);
+                                       
transferAllStateDataToDirectory(restoreStateHandle, 
temporaryRestoreInstancePath, stateBackend.restoringThreadNum, 
stateBackend.cancelStreamRegistry);
 
                                        stateMetaInfoSnapshots = 
readMetaData(restoreStateHandle.getMetaStateHandle());
                                        columnFamilyDescriptors = 
createAndRegisterColumnFamilyDescriptors(stateMetaInfoSnapshots);
@@ -1029,7 +1034,7 @@ private RestoredDBInstance 
restoreDBInstanceFromStateHandle(
                        IncrementalKeyedStateHandle restoreStateHandle,
                        Path temporaryRestoreInstancePath) throws Exception {
 
-                       transferAllStateDataToDirectory(restoreStateHandle, 
temporaryRestoreInstancePath);
+                       transferAllStateDataToDirectory(restoreStateHandle, 
temporaryRestoreInstancePath, stateBackend.restoringThreadNum, 
stateBackend.cancelStreamRegistry);
 
                        // read meta data
                        List<StateMetaInfoSnapshot> stateMetaInfoSnapshots =
@@ -1274,74 +1279,6 @@ private void restoreInstanceDirectoryFromPath(Path 
source) throws IOException {
                                }
                        }
                }
-
-               private void transferAllStateDataToDirectory(
-                       IncrementalKeyedStateHandle restoreStateHandle,
-                       Path dest) throws IOException {
-
-                       final Map<StateHandleID, StreamStateHandle> sstFiles =
-                               restoreStateHandle.getSharedState();
-                       final Map<StateHandleID, StreamStateHandle> miscFiles =
-                               restoreStateHandle.getPrivateState();
-
-                       transferAllDataFromStateHandles(sstFiles, dest);
-                       transferAllDataFromStateHandles(miscFiles, dest);
-               }
-
-               /**
-                * Copies all the files from the given stream state handles to 
the given path, renaming the files w.r.t. their
-                * {@link StateHandleID}.
-                */
-               private void transferAllDataFromStateHandles(
-                       Map<StateHandleID, StreamStateHandle> stateHandleMap,
-                       Path restoreInstancePath) throws IOException {
-
-                       for (Map.Entry<StateHandleID, StreamStateHandle> entry 
: stateHandleMap.entrySet()) {
-                               StateHandleID stateHandleID = entry.getKey();
-                               StreamStateHandle remoteFileHandle = 
entry.getValue();
-                               copyStateDataHandleData(new 
Path(restoreInstancePath, stateHandleID.toString()), remoteFileHandle);
-                       }
-
-               }
-
-               /**
-                * Copies the file from a single state handle to the given path.
-                */
-               private void copyStateDataHandleData(
-                       Path restoreFilePath,
-                       StreamStateHandle remoteFileHandle) throws IOException {
-
-                       FileSystem restoreFileSystem = 
restoreFilePath.getFileSystem();
-
-                       FSDataInputStream inputStream = null;
-                       FSDataOutputStream outputStream = null;
-
-                       try {
-                               inputStream = 
remoteFileHandle.openInputStream();
-                               
stateBackend.cancelStreamRegistry.registerCloseable(inputStream);
-
-                               outputStream = 
restoreFileSystem.create(restoreFilePath, FileSystem.WriteMode.OVERWRITE);
-                               
stateBackend.cancelStreamRegistry.registerCloseable(outputStream);
-
-                               byte[] buffer = new byte[8 * 1024];
-                               while (true) {
-                                       int numBytes = inputStream.read(buffer);
-                                       if (numBytes == -1) {
-                                               break;
-                                       }
-
-                                       outputStream.write(buffer, 0, numBytes);
-                               }
-                       } finally {
-                               if 
(stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) {
-                                       inputStream.close();
-                               }
-
-                               if 
(stateBackend.cancelStreamRegistry.unregisterCloseable(outputStream)) {
-                                       outputStream.close();
-                               }
-                       }
-               }
        }
 
        // 
------------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
index c85a7b2077c..9b15bf13fd1 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
@@ -45,4 +45,12 @@
                .withDescription(String.format("This determines the factory for 
timer service state implementation. Options " +
                        "are either %s (heap-based, default) or %s for an 
implementation based on RocksDB .",
                        HEAP.name(), ROCKSDB.name()));
+
+       /**
+        * The number of threads used to download files from DFS in 
RocksDBStateBackend.
+        */
+       public static final ConfigOption<Integer> CHECKPOINT_RESTORE_THREAD_NUM 
= ConfigOptions
+               .key("state.backend.rocksdb.checkpoint.restore.thread.num")
+               .defaultValue(1)
+               .withDescription("The number of threads used to download files 
from DFS in RocksDBStateBackend.");
 }
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 794e22160a2..080e7cfda72 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -40,6 +40,7 @@
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.AbstractID;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TernaryBoolean;
 
 import org.rocksdb.ColumnFamilyOptions;
@@ -97,6 +98,8 @@
        /** Flag whether the native library has been loaded. */
        private static boolean rocksDbInitialized = false;
 
+       private static final int UNDEFINED_NUMBER_OF_RESTORING_THREADS = -1;
+
        // 
------------------------------------------------------------------------
 
        // -- configuration values, set in the application / configuration
@@ -120,6 +123,9 @@
        /** This determines if incremental checkpointing is enabled. */
        private final TernaryBoolean enableIncrementalCheckpointing;
 
+       /** Thread number used to download from DFS when restore, default 
value: 1. */
+       private int numberOfRestoringThreads;
+
        /** This determines the type of priority queue state. */
        private final PriorityQueueStateType priorityQueueStateType;
 
@@ -238,6 +244,7 @@ public RocksDBStateBackend(StateBackend 
checkpointStreamBackend) {
        public RocksDBStateBackend(StateBackend checkpointStreamBackend, 
TernaryBoolean enableIncrementalCheckpointing) {
                this.checkpointStreamBackend = 
checkNotNull(checkpointStreamBackend);
                this.enableIncrementalCheckpointing = 
enableIncrementalCheckpointing;
+               this.numberOfRestoringThreads = 
UNDEFINED_NUMBER_OF_RESTORING_THREADS;
                // for now, we use still the heap-based implementation as 
default
                this.priorityQueueStateType = PriorityQueueStateType.HEAP;
                this.defaultMetricOptions = new RocksDBNativeMetricOptions();
@@ -276,6 +283,12 @@ private RocksDBStateBackend(RocksDBStateBackend original, 
Configuration config)
                this.enableIncrementalCheckpointing = 
original.enableIncrementalCheckpointing.resolveUndefined(
                        
config.getBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS));
 
+               if (original.numberOfRestoringThreads == 
UNDEFINED_NUMBER_OF_RESTORING_THREADS) {
+                       this.numberOfRestoringThreads = 
config.getInteger(RocksDBOptions.CHECKPOINT_RESTORE_THREAD_NUM);
+               } else {
+                       this.numberOfRestoringThreads = 
original.numberOfRestoringThreads;
+               }
+
                final String priorityQueueTypeString = 
config.getString(TIMER_SERVICE_FACTORY);
 
                this.priorityQueueStateType = priorityQueueTypeString.length() 
> 0 ?
@@ -452,6 +465,7 @@ public CheckpointStorage createCheckpointStorage(JobID 
jobId) throws IOException
                                keyGroupRange,
                                env.getExecutionConfig(),
                                isIncrementalCheckpointsEnabled(),
+                               getNumberOfRestoringThreads(),
                                localRecoveryConfig,
                                priorityQueueStateType,
                                ttlTimeProvider,
@@ -686,6 +700,20 @@ public RocksDBNativeMetricOptions 
getMemoryWatcherOptions() {
                return options;
        }
 
+       /**
+        * Gets the thread number will used for downloading files from DFS when 
restore.
+        */
+       public int getNumberOfRestoringThreads() {
+               return numberOfRestoringThreads == 
UNDEFINED_NUMBER_OF_RESTORING_THREADS ?
+                       
RocksDBOptions.CHECKPOINT_RESTORE_THREAD_NUM.defaultValue() : 
numberOfRestoringThreads;
+       }
+
+       public void setNumberOfRestoringThreads(int numberOfRestoringThreads) {
+               Preconditions.checkArgument(numberOfRestoringThreads > 0,
+                       "The number of threads used to download files from DFS 
in RocksDBStateBackend should > 0.");
+               this.numberOfRestoringThreads = numberOfRestoringThreads;
+       }
+
        // 
------------------------------------------------------------------------
        //  utilities
        // 
------------------------------------------------------------------------
@@ -696,6 +724,7 @@ public String toString() {
                                "checkpointStreamBackend=" + 
checkpointStreamBackend +
                                ", localRocksDbDirectories=" + 
Arrays.toString(localRocksDbDirectories) +
                                ", enableIncrementalCheckpointing=" + 
enableIncrementalCheckpointing +
+                               ", numberOfRestoringThreads=" + 
numberOfRestoringThreads +
                                '}';
        }
 
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java
new file mode 100644
index 00000000000..03e114da282
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.concurrent.FutureUtils;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
+import org.apache.flink.runtime.state.StateHandleID;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.function.ThrowingRunnable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import static 
org.apache.flink.runtime.concurrent.Executors.newDirectExecutorService;
+
+/**
+ * Data transfer utils for {@link RocksDBKeyedStateBackend}.
+ */
+class RocksDbStateDataTransfer {
+
+       static void transferAllStateDataToDirectory(
+               IncrementalKeyedStateHandle restoreStateHandle,
+               Path dest,
+               int restoringThreadNum,
+               CloseableRegistry closeableRegistry) throws Exception {
+
+               final Map<StateHandleID, StreamStateHandle> sstFiles =
+                       restoreStateHandle.getSharedState();
+               final Map<StateHandleID, StreamStateHandle> miscFiles =
+                       restoreStateHandle.getPrivateState();
+
+               downloadDataForAllStateHandles(sstFiles, dest, 
restoringThreadNum, closeableRegistry);
+               downloadDataForAllStateHandles(miscFiles, dest, 
restoringThreadNum, closeableRegistry);
+       }
+
+       /**
+        * Copies all the files from the given stream state handles to the 
given path, renaming the files w.r.t. their
+        * {@link StateHandleID}.
+        */
+       private static void downloadDataForAllStateHandles(
+               Map<StateHandleID, StreamStateHandle> stateHandleMap,
+               Path restoreInstancePath,
+               int restoringThreadNum,
+               CloseableRegistry closeableRegistry) throws Exception {
+
+               final ExecutorService executorService = 
createExecutorService(restoringThreadNum);
+
+               try {
+                       List<Runnable> runnables = 
createDownloadRunnables(stateHandleMap, restoreInstancePath, closeableRegistry);
+                       List<CompletableFuture<Void>> futures = new 
ArrayList<>(runnables.size());
+                       for (Runnable runnable : runnables) {
+                               
futures.add(CompletableFuture.runAsync(runnable, executorService));
+                       }
+                       FutureUtils.waitForAll(futures).get();
+               } catch (ExecutionException e) {
+                       Throwable throwable = 
ExceptionUtils.stripExecutionException(e);
+                       throwable = ExceptionUtils.stripException(throwable, 
RuntimeException.class);
+                       if (throwable instanceof IOException) {
+                               throw (IOException) throwable;
+                       } else {
+                               throw new FlinkRuntimeException("Failed to 
download data for state handles.", e);
+                       }
+               } finally {
+                       executorService.shutdownNow();
+               }
+       }
+
+       private static ExecutorService createExecutorService(int threadNum) {
+               if (threadNum > 1) {
+                       return Executors.newFixedThreadPool(threadNum);
+               } else {
+                       return newDirectExecutorService();
+               }
+       }
+
+       private static List<Runnable> createDownloadRunnables(
+               Map<StateHandleID, StreamStateHandle> stateHandleMap,
+               Path restoreInstancePath,
+               CloseableRegistry closeableRegistry) {
+               List<Runnable> runnables = new 
ArrayList<>(stateHandleMap.size());
+               for (Map.Entry<StateHandleID, StreamStateHandle> entry : 
stateHandleMap.entrySet()) {
+                       StateHandleID stateHandleID = entry.getKey();
+                       StreamStateHandle remoteFileHandle = entry.getValue();
+
+                       Path path = new Path(restoreInstancePath, 
stateHandleID.toString());
+
+                       runnables.add(ThrowingRunnable.unchecked(
+                               () -> downloadDataForStateHandle(path, 
remoteFileHandle, closeableRegistry)));
+               }
+               return runnables;
+       }
+
+       /**
+        * Copies the file from a single state handle to the given path.
+        */
+       private static void downloadDataForStateHandle(
+               Path restoreFilePath,
+               StreamStateHandle remoteFileHandle,
+               CloseableRegistry closeableRegistry) throws IOException {
+
+               FSDataInputStream inputStream = null;
+               FSDataOutputStream outputStream = null;
+
+               try {
+                       FileSystem restoreFileSystem = 
restoreFilePath.getFileSystem();
+                       inputStream = remoteFileHandle.openInputStream();
+                       closeableRegistry.registerCloseable(inputStream);
+
+                       outputStream = 
restoreFileSystem.create(restoreFilePath, FileSystem.WriteMode.OVERWRITE);
+                       closeableRegistry.registerCloseable(outputStream);
+
+                       byte[] buffer = new byte[8 * 1024];
+                       while (true) {
+                               int numBytes = inputStream.read(buffer);
+                               if (numBytes == -1) {
+                                       break;
+                               }
+
+                               outputStream.write(buffer, 0, numBytes);
+                       }
+               } finally {
+                       if (closeableRegistry.unregisterCloseable(inputStream)) 
{
+                               inputStream.close();
+                       }
+
+                       if 
(closeableRegistry.unregisterCloseable(outputStream)) {
+                               outputStream.close();
+                       }
+               }
+       }
+}
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index d7d6bdea879..0796c4f00fe 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -249,6 +249,7 @@ public void testCorrectMergeOperatorSet() throws 
IOException {
                                new KeyGroupRange(0, 0),
                                new ExecutionConfig(),
                                enableIncrementalCheckpointing,
+                               1,
                                TestLocalRecoveryConfig.disabled(),
                                RocksDBStateBackend.PriorityQueueStateType.HEAP,
                                TtlTimeProvider.DEFAULT,
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java
new file mode 100644
index 00000000000..5b01e438006
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.StateHandleID;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests for {@link RocksDbStateDataTransfer}.
+ */
+public class RocksDBStateDataTransferTest extends TestLogger {
+       @Rule
+       public final TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+       /**
+        * Test that the exception arose in the thread pool will rethrow to the 
main thread.
+        */
+       @Test
+       public void testThreadPoolExceptionRethrow() {
+               SpecifiedException expectedException = new 
SpecifiedException("throw exception while multi thread restore.");
+               StreamStateHandle stateHandle = new StreamStateHandle() {
+                       @Override
+                       public FSDataInputStream openInputStream() throws 
IOException {
+                               throw expectedException;
+                       }
+
+                       @Override
+                       public void discardState() {
+
+                       }
+
+                       @Override
+                       public long getStateSize() {
+                               return 0;
+                       }
+               };
+
+               Map<StateHandleID, StreamStateHandle> stateHandles = new 
HashMap<>(1);
+               stateHandles.put(new StateHandleID("state1"), stateHandle);
+
+               IncrementalKeyedStateHandle incrementalKeyedStateHandle =
+                       new IncrementalKeyedStateHandle(
+                               UUID.randomUUID(),
+                               KeyGroupRange.EMPTY_KEY_GROUP_RANGE,
+                               1,
+                               stateHandles,
+                               stateHandles,
+                               stateHandle);
+
+               try {
+                       
RocksDbStateDataTransfer.transferAllStateDataToDirectory(incrementalKeyedStateHandle,
 new Path(temporaryFolder.newFolder().toURI()), 5, new CloseableRegistry());
+                       fail();
+               } catch (Exception e) {
+                       assertEquals(expectedException, e);
+               }
+       }
+
+       /**
+        * Tests that download files with multi-thread correctly.
+        */
+       @Test
+       public void testMultiThreadRestoreCorrectly() throws Exception {
+               Random random = new Random();
+               int contentNum = 6;
+               byte[][] contents = new byte[contentNum][];
+               for (int i = 0; i < contentNum; ++i) {
+                       contents[i] = new byte[random.nextInt(100000) + 1];
+                       random.nextBytes(contents[i]);
+               }
+
+               List<StreamStateHandle> handles = new ArrayList<>(contentNum);
+               for (int i = 0; i < contentNum; ++i) {
+                       handles.add(new 
ByteStreamStateHandle(String.format("state%d", i), contents[i]));
+               }
+
+               Map<StateHandleID, StreamStateHandle> sharedStates = new 
HashMap<>(contentNum);
+               Map<StateHandleID, StreamStateHandle> privateStates = new 
HashMap<>(contentNum);
+               for (int i = 0; i < contentNum; ++i) {
+                       sharedStates.put(new 
StateHandleID(String.format("sharedState%d", i)), handles.get(i));
+                       privateStates.put(new 
StateHandleID(String.format("privateState%d", i)), handles.get(i));
+               }
+
+               IncrementalKeyedStateHandle incrementalKeyedStateHandle =
+                       new IncrementalKeyedStateHandle(
+                               UUID.randomUUID(),
+                               KeyGroupRange.of(0, 1),
+                               1,
+                               sharedStates,
+                               privateStates,
+                               handles.get(0));
+
+               Path dstPath = new Path(temporaryFolder.newFolder().toURI());
+               
RocksDbStateDataTransfer.transferAllStateDataToDirectory(incrementalKeyedStateHandle,
 dstPath, contentNum - 1, new CloseableRegistry());
+
+               for (int i = 0; i < contentNum; ++i) {
+                       assertStateContentEqual(contents[i], new Path(dstPath, 
String.format("sharedState%d", i)));
+               }
+       }
+
+       private void assertStateContentEqual(byte[] expected, Path path) throws 
IOException {
+               byte[] actual = Files.readAllBytes(Paths.get(path.toUri()));
+               assertArrayEquals(expected, actual);
+       }
+
+       private static class SpecifiedException extends IOException {
+               SpecifiedException(String message) {
+                       super(message);
+               }
+       }
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index a593bd5e004..8dc4b447ba0 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -91,7 +91,6 @@
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
-import org.apache.flink.runtime.util.DirectExecutorService;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -137,6 +136,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static 
org.apache.flink.runtime.concurrent.Executors.newDirectExecutorService;
 import static org.hamcrest.Matchers.everyItem;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasSize;
@@ -423,7 +423,7 @@ public void testFailingAsyncCheckpointRunnable() throws 
Exception {
                Whitebox.setInternalState(streamTask, "lock", new Object());
                Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
                Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", new DirectExecutorService());
+               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", newDirectExecutorService());
                Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
                Whitebox.setInternalState(streamTask, "checkpointStorage", new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE));
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to