[ https://issues.apache.org/jira/browse/FLINK-10122?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16609077#comment-16609077 ]
ASF GitHub Bot commented on FLINK-10122: ---------------------------------------- StefanRRichter closed pull request #6537: [FLINK-10122] KafkaConsumer should use partitionable state over union state if partition discovery is not active URL: https://github.com/apache/flink/pull/6537 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java index cfb5b6d510d..3857a968dd5 100644 --- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java +++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java @@ -105,8 +105,11 @@ /** Configuration key to define the consumer's partition discovery interval, in milliseconds. */ public static final String KEY_PARTITION_DISCOVERY_INTERVAL_MILLIS = "flink.partition-discovery.interval-millis"; + /** For backwards compatibility. */ + private static final String OLD_OFFSETS_STATE_NAME = "topic-partition-offset-states"; + /** State name of the consumer's partition offset states. 
*/ - private static final String OFFSETS_STATE_NAME = "topic-partition-offset-states"; + private static final String OFFSETS_STATE_NAME = "kafka-consumer-offsets"; // ------------------------------------------------------------------------ // configuration state, set on the client relevant for all subtasks @@ -180,13 +183,7 @@ private transient volatile TreeMap<KafkaTopicPartition, Long> restoredState; /** Accessor for state in the operator state backend. */ - private transient ListState<Tuple2<KafkaTopicPartition, Long>> unionOffsetStates; - - /** - * Flag indicating whether the consumer is restored from older state written with Flink 1.1 or 1.2. - * When the current run is restored from older state, partition discovery is disabled. - */ - private boolean restoredFromOldState; + private transient ListState<Tuple2<KafkaTopicPartition, Long>> offsetsState; /** Discovery loop, executed in a separate thread. */ private transient volatile Thread discoveryLoopThread; @@ -480,7 +477,7 @@ public void open(Configuration configuration) throws Exception { } for (Map.Entry<KafkaTopicPartition, Long> restoredStateEntry : restoredState.entrySet()) { - if (!restoredFromOldState) { + if (discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) { // seed the partition discoverer with the union state while filtering out // restored partitions that should not be subscribed by this subtask if (KafkaTopicPartitionAssigner.assign( @@ -489,8 +486,7 @@ public void open(Configuration configuration) throws Exception { subscribedPartitionsToStartOffsets.put(restoredStateEntry.getKey(), restoredStateEntry.getValue()); } } else { - // when restoring from older 1.1 / 1.2 state, the restored state would not be the union state; - // in this case, just use the restored state as the subscribed partitions + // just restore from assigned partitions subscribedPartitionsToStartOffsets.put(restoredStateEntry.getKey(), restoredStateEntry.getValue()); } } @@ -783,30 +779,26 @@ public final void 
initializeState(FunctionInitializationContext context) throws OperatorStateStore stateStore = context.getOperatorStateStore(); - ListState<Tuple2<KafkaTopicPartition, Long>> oldRoundRobinListState = - stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + final TypeInformation<Tuple2<KafkaTopicPartition, Long>> offsetStateTypeInfo = + TypeInformation.of(new TypeHint<Tuple2<KafkaTopicPartition, Long>>() {}); - this.unionOffsetStates = stateStore.getUnionListState(new ListStateDescriptor<>( - OFFSETS_STATE_NAME, - TypeInformation.of(new TypeHint<Tuple2<KafkaTopicPartition, Long>>() {}))); + ListStateDescriptor<Tuple2<KafkaTopicPartition, Long>> offsetStateDescriptor = + new ListStateDescriptor<>(OFFSETS_STATE_NAME, offsetStateTypeInfo); - if (context.isRestored() && !restoredFromOldState) { - restoredState = new TreeMap<>(new KafkaTopicPartition.Comparator()); + this.offsetsState = + discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED ? + stateStore.getUnionListState(offsetStateDescriptor) : stateStore.getListState(offsetStateDescriptor); - // migrate from 1.2 state, if there is any - for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldRoundRobinListState.get()) { - restoredFromOldState = true; - unionOffsetStates.add(kafkaOffset); - } - oldRoundRobinListState.clear(); + if (context.isRestored()) { - if (restoredFromOldState && discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) { - throw new IllegalArgumentException( - "Topic / partition discovery cannot be enabled if the job is restored from a savepoint from Flink 1.2.x."); - } + restoredState = new TreeMap<>(new KafkaTopicPartition.Comparator()); + + // backwards compatibility + handleMigration_1_2(stateStore); + handleMigration_1_6(stateStore, offsetStateTypeInfo); // populate actual holder for restored state - for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : unionOffsetStates.get()) { + for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : 
offsetsState.get()) { restoredState.put(kafkaOffset.f0, kafkaOffset.f1); } @@ -816,19 +808,60 @@ public final void initializeState(FunctionInitializationContext context) throws } } + private void handleMigration_1_2( + OperatorStateStore stateStore) throws Exception { + boolean restoredFromOldState = false; + ListState<Tuple2<KafkaTopicPartition, Long>> oldRoundRobinListState = + stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldRoundRobinListState.get()) { + restoredFromOldState = true; + offsetsState.add(kafkaOffset); + } + + // we remove this state again immediately so it will no longer exist in future check/savepoints + oldRoundRobinListState.clear(); + stateStore.removeOperatorState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + + if (restoredFromOldState) { + if (discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) { + throw new IllegalArgumentException( + "Topic / partition discovery cannot be enabled if the job is restored from a savepoint from Flink 1.2.x."); + } + } + } + + private void handleMigration_1_6( + OperatorStateStore stateStore, + TypeInformation<Tuple2<KafkaTopicPartition, Long>> offsetStateTypeInfo) throws Exception { + + ListStateDescriptor<Tuple2<KafkaTopicPartition, Long>> oldUnionStateDescriptor = + new ListStateDescriptor<>(OLD_OFFSETS_STATE_NAME, offsetStateTypeInfo); + + ListState<Tuple2<KafkaTopicPartition, Long>> oldUnionListState = + stateStore.getUnionListState(oldUnionStateDescriptor); + + for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : oldUnionListState.get()) { + offsetsState.add(kafkaOffset); + } + + // we remove this state again immediately so it will no longer exist in future check/savepoints + oldUnionListState.clear(); + stateStore.removeOperatorState(oldUnionStateDescriptor.getName()); + } + @Override public final void snapshotState(FunctionSnapshotContext context) throws Exception { if (!running) { 
LOG.debug("snapshotState() called on closed source"); } else { - unionOffsetStates.clear(); + offsetsState.clear(); final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher; if (fetcher == null) { // the fetcher has not yet been initialized, which means we need to return the // originally restored offsets or the assigned partitions for (Map.Entry<KafkaTopicPartition, Long> subscribedPartition : subscribedPartitionsToStartOffsets.entrySet()) { - unionOffsetStates.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue())); + offsetsState.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue())); } if (offsetCommitMode == OffsetCommitMode.ON_CHECKPOINTS) { @@ -846,7 +879,7 @@ public final void snapshotState(FunctionSnapshotContext context) throws Exceptio } for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) { - unionOffsetStates.add( + offsetsState.add( Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue())); } } diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java index c9b52415a3e..26eff023b29 100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java @@ -643,7 +643,7 @@ private void testRescaling( Collections.singletonList("dummy-topic"), null, (KeyedDeserializationSchema < T >) mock(KeyedDeserializationSchema.class), - PARTITION_DISCOVERY_DISABLED, + 1L, false); this.testFetcher = testFetcher; @@ -884,16 +884,16 @@ public OperatorID getOperatorID() { private static class MockOperatorStateStore 
implements OperatorStateStore { - private final ListState<?> mockRestoredUnionListState; + private final ListState<?> mockListState; private MockOperatorStateStore(ListState<?> restoredUnionListState) { - this.mockRestoredUnionListState = restoredUnionListState; + this.mockListState = restoredUnionListState; } @Override @SuppressWarnings("unchecked") public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor) throws Exception { - return (ListState<S>) mockRestoredUnionListState; + return (ListState<S>) mockListState; } @Override @@ -914,9 +914,20 @@ private MockOperatorStateStore(ListState<?> restoredUnionListState) { throw new UnsupportedOperationException(); } + @SuppressWarnings("unchecked") @Override public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception { - throw new UnsupportedOperationException(); + return (ListState<S>) mockListState; + } + + @Override + public void removeOperatorState(String name) throws Exception { + + } + + @Override + public void removeBroadcastState(String name) throws Exception { + } @Override diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java index 7a998e6149c..7f0cd6ad0a2 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java @@ -75,6 +75,16 @@ */ <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception; + /** + * Removes the operator state with the given name from the state store, if it exists. + */ + void removeOperatorState(String name) throws Exception; + + /** + * Removes the broadcast state with the given name from the state store, if it exists. + */ + void removeBroadcastState(String name) throws Exception; + /** * Creates (or restores) a list state. 
Each state is registered under a unique name. * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore). diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java index 090f48a3c87..ce2d40309dd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java @@ -20,7 +20,6 @@ import org.apache.flink.runtime.state.OperatorStateHandle; -import java.util.Collection; import java.util.List; /** @@ -36,7 +35,7 @@ * @return List with one entry per parallel subtask. Each subtask receives now one collection of states that build * of the new total state for this subtask. */ - List<Collection<OperatorStateHandle>> repartitionState( + List<List<OperatorStateHandle>> repartitionState( List<OperatorStateHandle> previousParallelSubtaskStates, int newParallelism); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index e6fa687fd14..4705265430e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -26,11 +26,11 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.EnumMap; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; /** * Current default implementation of {@link OperatorStateRepartitioner} that redistributes state in round robin fashion. 
@@ -41,7 +41,7 @@ private static final boolean OPTIMIZE_MEMORY_USE = false; @Override - public List<Collection<OperatorStateHandle>> repartitionState( + public List<List<OperatorStateHandle>> repartitionState( List<OperatorStateHandle> previousParallelSubtaskStates, int newParallelism) { @@ -56,7 +56,7 @@ } // Assemble result from all merge maps - List<Collection<OperatorStateHandle>> result = new ArrayList<>(newParallelism); + List<List<OperatorStateHandle>> result = new ArrayList<>(newParallelism); // Do the actual repartitioning for all named states List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = @@ -93,20 +93,19 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ continue; } - for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : - psh.getStateNameToPartitionOffsets().entrySet()) { + final Set<Map.Entry<String, OperatorStateHandle.StateMetaInfo>> partitionOffsetEntries = + psh.getStateNameToPartitionOffsets().entrySet(); + + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : partitionOffsetEntries) { OperatorStateHandle.StateMetaInfo metaInfo = e.getValue(); Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState = nameToStateByMode.get(metaInfo.getDistributionMode()); List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations = - nameToState.get(e.getKey()); - - if (stateLocations == null) { - stateLocations = new ArrayList<>(); - nameToState.put(e.getKey(), stateLocations); - } + nameToState.computeIfAbsent( + e.getKey(), + k -> new ArrayList<>(previousParallelSubtaskStates.size() * partitionOffsetEntries.size())); stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue())); } @@ -203,7 +202,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx); OperatorStateHandle operatorStateHandle = 
mergeMap.get(handleWithOffsets.f0); if (operatorStateHandle == null) { - operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithOffsets.f0); + operatorStateHandle = new OperatorStreamStateHandle( + new HashMap<>(distributeNameToState.size()), + handleWithOffsets.f0); mergeMap.put(handleWithOffsets.f0, operatorStateHandle); } operatorStateHandle.getStateNameToPartitionOffsets().put( @@ -229,7 +230,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : e.getValue()) { OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0); if (operatorStateHandle == null) { - operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithMetaInfo.f0); + operatorStateHandle = new OperatorStreamStateHandle( + new HashMap<>(broadcastNameToState.size()), + handleWithMetaInfo.f0); mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle); } operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1); @@ -256,7 +259,9 @@ private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previ OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0); if (operatorStateHandle == null) { - operatorStateHandle = new OperatorStreamStateHandle(new HashMap<>(), handleWithMetaInfo.f0); + operatorStateHandle = new OperatorStreamStateHandle( + new HashMap<>(uniformBroadcastNameToState.size()), + handleWithMetaInfo.f0); mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle); } operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 592489f2baf..b0173886d57 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -24,11 +24,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.OperatorInstanceID; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -83,7 +83,7 @@ public boolean assignStates() throws Exception { // find the states of all operators belonging to this task List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs(); List<OperatorID> altOperatorIDs = executionJobVertex.getUserDefinedOperatorIDs(); - List<OperatorState> operatorStates = new ArrayList<>(); + List<OperatorState> operatorStates = new ArrayList<>(operatorIDs.size()); boolean statelessTask = true; for (int x = 0; x < operatorIDs.size(); x++) { OperatorID operatorID = altOperatorIDs.get(x) == null @@ -124,7 +124,9 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper executionJobVertex.getMaxParallelism(), newParallelism); - /** + final int expectedNumberOfSubTasks = newParallelism * operatorIDs.size(); + + /* * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism. 
* * The old ManagedOperatorStates with old parallelism 3: @@ -143,8 +145,10 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper * op2 state2,0 state2,1 state2,2 state2,3 * op3 state3,0 state3,1 state3,2 state3,3 */ - Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates = new HashMap<>(); - Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates = new HashMap<>(); + Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates = + new HashMap<>(expectedNumberOfSubTasks); + Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates = + new HashMap<>(expectedNumberOfSubTasks); reDistributePartitionableStates( operatorStates, @@ -153,8 +157,10 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper newManagedOperatorStates, newRawOperatorStates); - Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState = new HashMap<>(); - Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState = new HashMap<>(); + Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState = + new HashMap<>(expectedNumberOfSubTasks); + Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState = + new HashMap<>(expectedNumberOfSubTasks); reDistributeKeyedStates( operatorStates, @@ -164,7 +170,7 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<Oper newManagedKeyedState, newRawKeyedState); - /** + /* * An executionJobVertex's all state handles needed to restore are something like a matrix * * parallelism0 parallelism1 parallelism2 parallelism3 @@ -198,7 +204,7 @@ private void assignTaskStateToExecutionJobVertices( Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex] .getCurrentExecutionAttempt(); - TaskStateSnapshot taskState = new TaskStateSnapshot(); + TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size()); boolean statelessTask = true; for (OperatorID 
operatorID : operatorIDs) { @@ -276,38 +282,34 @@ private void reDistributeKeyedStates( for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) { OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex)); if (isHeadOperator(operatorIndex, newOperatorIDs)) { - Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates( + Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates( operatorState, newKeyGroupPartitions, subTaskIndex, newParallelism, oldParallelism); - newManagedKeyedState - .computeIfAbsent(instanceID, key -> new ArrayList<>()) - .addAll(subKeyedStates.f0); - newRawKeyedState - .computeIfAbsent(instanceID, key -> new ArrayList<>()) - .addAll(subKeyedStates.f1); + newManagedKeyedState.put(instanceID, subKeyedStates.f0); + newRawKeyedState.put(instanceID, subKeyedStates.f1); } } } } // TODO rewrite based on operator id - private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates( + private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates( OperatorState operatorState, List<KeyGroupRange> keyGroupPartitions, int subTaskIndex, int newParallelism, int oldParallelism) { - Collection<KeyedStateHandle> subManagedKeyedState; - Collection<KeyedStateHandle> subRawKeyedState; + List<KeyedStateHandle> subManagedKeyedState; + List<KeyedStateHandle> subRawKeyedState; if (newParallelism == oldParallelism) { if (operatorState.getState(subTaskIndex) != null) { - subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); - subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); + subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState().asList(); + subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState().asList(); } else { subManagedKeyedState = Collections.emptyList(); 
subRawKeyedState = Collections.emptyList(); @@ -336,8 +338,8 @@ private void reDistributePartitionableStates( "This method still depends on the order of the new and old operators"); //collect the old partitionable state - List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(); - List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(); + List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(oldOperatorStates.size()); + List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(oldOperatorStates.size()); collectPartionableStates(oldOperatorStates, oldManagedOperatorStates, oldRawOperatorStates); @@ -368,24 +370,29 @@ private void collectPartionableStates( List<List<OperatorStateHandle>> rawOperatorStates) { for (OperatorState operatorState : operatorStates) { + + final int parallelism = operatorState.getParallelism(); + List<OperatorStateHandle> managedOperatorState = null; List<OperatorStateHandle> rawOperatorState = null; - for (int i = 0; i < operatorState.getParallelism(); i++) { + for (int i = 0; i < parallelism; i++) { OperatorSubtaskState operatorSubtaskState = operatorState.getState(i); if (operatorSubtaskState != null) { + StateObjectCollection<OperatorStateHandle> managed = operatorSubtaskState.getManagedOperatorState(); + StateObjectCollection<OperatorStateHandle> raw = operatorSubtaskState.getRawOperatorState(); + if (managedOperatorState == null) { - managedOperatorState = new ArrayList<>(); + managedOperatorState = new ArrayList<>(parallelism * managed.size()); } - managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState()); + managedOperatorState.addAll(managed); if (rawOperatorState == null) { - rawOperatorState = new ArrayList<>(); + rawOperatorState = new ArrayList<>(parallelism * raw.size()); } - rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState()); + rawOperatorState.addAll(raw); } - } managedOperatorStates.add(managedOperatorState); 
rawOperatorStates.add(rawOperatorState); @@ -404,12 +411,19 @@ private void collectPartionableStates( OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) { - List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>(); + final int parallelism = operatorState.getParallelism(); - for (int i = 0; i < operatorState.getParallelism(); i++) { + List<KeyedStateHandle> subtaskKeyedStateHandles = null; + + for (int i = 0; i < parallelism; i++) { if (operatorState.getState(i) != null) { Collection<KeyedStateHandle> keyedStateHandles = operatorState.getState(i).getManagedKeyedState(); + + if (subtaskKeyedStateHandles == null) { + subtaskKeyedStateHandles = new ArrayList<>(parallelism * keyedStateHandles.size()); + } + extractIntersectingState( keyedStateHandles, subtaskKeyGroupRange, @@ -432,11 +446,19 @@ private void collectPartionableStates( OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) { - List<KeyedStateHandle> extractedKeyedStateHandles = new ArrayList<>(); + final int parallelism = operatorState.getParallelism(); - for (int i = 0; i < operatorState.getParallelism(); i++) { + List<KeyedStateHandle> extractedKeyedStateHandles = null; + + for (int i = 0; i < parallelism; i++) { if (operatorState.getState(i) != null) { + Collection<KeyedStateHandle> rawKeyedState = operatorState.getState(i).getRawKeyedState(); + + if (extractedKeyedStateHandles == null) { + extractedKeyedStateHandles = new ArrayList<>(parallelism * rawKeyedState.size()); + } + extractIntersectingState( rawKeyedState, subtaskKeyGroupRange, @@ -565,19 +587,18 @@ private static void checkStateMappingCompleteness( List<OperatorStateHandle> chainOpParallelStates, int oldParallelism, int newParallelism) { - Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>(); - List<Collection<OperatorStateHandle>> states = applyRepartitioner( + List<List<OperatorStateHandle>> states = applyRepartitioner( opStateRepartitioner, chainOpParallelStates, oldParallelism, 
newParallelism); + Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>(states.size()); + for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) { checkNotNull(states.get(subtaskIndex) != null, "states.get(subtaskIndex) is null"); - result - .computeIfAbsent(OperatorInstanceID.of(subtaskIndex, operatorID), key -> new ArrayList<>()) - .addAll(states.get(subtaskIndex)); + result.put(OperatorInstanceID.of(subtaskIndex, operatorID), states.get(subtaskIndex)); } return result; @@ -594,7 +615,7 @@ private static void checkStateMappingCompleteness( * @return repartitioned state */ // TODO rewrite based on operator id - public static List<Collection<OperatorStateHandle>> applyRepartitioner( + public static List<List<OperatorStateHandle>> applyRepartitioner( OperatorStateRepartitioner opStateRepartitioner, List<OperatorStateHandle> chainOpParallelStates, int oldParallelism, @@ -611,7 +632,7 @@ private static void checkStateMappingCompleteness( chainOpParallelStates, newParallelism); } else { - List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism); + List<List<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism); for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { if (operatorStateHandle != null) { @@ -645,7 +666,7 @@ private static void checkStateMappingCompleteness( Collection<? 
extends KeyedStateHandle> keyedStateHandles, KeyGroupRange subtaskKeyGroupRange) { - List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>(); + List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>(keyedStateHandles.size()); for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java index 38e3d15da29..30768477eed 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateObjectCollection.java @@ -27,6 +27,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.function.Predicate; /** @@ -178,6 +179,14 @@ public String toString() { return "StateObjectCollection{" + stateObjects + '}'; } + public List<T> asList() { + return stateObjects instanceof List ? + (List<T>) stateObjects : + stateObjects != null ? + new ArrayList<>(stateObjects) : + Collections.emptyList(); + } + // ------------------------------------------------------------------------ // Helper methods. 
// ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java index a44a508ecbc..76ab7a5f9c3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java @@ -25,8 +25,8 @@ import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.util.Preconditions; -import java.util.ArrayList; -import java.util.List; +import java.util.HashMap; +import java.util.Map; /** * A helper for KvState registrations of a single task. @@ -43,7 +43,7 @@ private final JobVertexID jobVertexId; /** List of all registered KvState instances of this task. */ - private final List<KvStateInfo> registeredKvStates = new ArrayList<>(); + private final Map<String, KvStateInfo> registeredKvStates = new HashMap<>(); TaskKvStateRegistry(KvStateRegistry registry, JobID jobId, JobVertexID jobVertexId) { this.registry = Preconditions.checkNotNull(registry, "KvStateRegistry"); @@ -61,19 +61,35 @@ * @param kvState The */ public void registerKvState(KeyGroupRange keyGroupRange, String registrationName, InternalKvState<?, ?, ?> kvState) { + unregisterKvState(registrationName); KvStateID kvStateId = registry.registerKvState(jobId, jobVertexId, keyGroupRange, registrationName, kvState); - registeredKvStates.add(new KvStateInfo(keyGroupRange, registrationName, kvStateId)); + registeredKvStates.put(registrationName, new KvStateInfo(keyGroupRange, registrationName, kvStateId)); + } + + /** + * + * @param registrationName + */ + public void unregisterKvState(String registrationName) { + KvStateInfo kvStateInfo = registeredKvStates.get(registrationName); + if (kvStateInfo != null) { + unregisterInternal(kvStateInfo); + } } /** * Unregisters all registered KvState instances 
from the KvStateRegistry. */ public void unregisterAll() { - for (KvStateInfo kvState : registeredKvStates) { - registry.unregisterKvState(jobId, jobVertexId, kvState.keyGroupRange, kvState.registrationName, kvState.kvStateId); + for (KvStateInfo kvState : registeredKvStates.values()) { + unregisterInternal(kvState); } } + private void unregisterInternal(KvStateInfo kvState) { + registry.unregisterKvState(jobId, jobVertexId, kvState.keyGroupRange, kvState.registrationName, kvState.kvStateId); + } + /** * 3-tuple holding registered KvState meta data. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 1c2d2a3ecaf..d80b76bec85 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -62,7 +62,7 @@ private int currentKeyGroup; /** So that we can give out state when the user uses the same key. */ - private final HashMap<String, InternalKvState<K, ?, ?>> keyValueStatesByName; + protected final HashMap<String, InternalKvState<K, ?, ?>> keyValueStatesByName; /** For caching the last accessed partitioned state. 
*/ private String lastName; @@ -319,5 +319,4 @@ StreamCompressionDecorator getKeyGroupCompressionDecorator() { public boolean requiresLegacySynchronousTimerSnapshots() { return false; } - } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index d9fc41e6529..c80f516bcbe 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -180,6 +180,35 @@ public void dispose() { // ------------------------------------------------------------------------------------------- // State access methods // ------------------------------------------------------------------------------------------- + @Override + public void removeBroadcastState(String name) { + restoredBroadcastStateMetaInfos.remove(name); + if (registeredBroadcastStates.remove(name) != null) { + accessedBroadcastStatesByName.remove(name); + } + } + + @Override + public void removeOperatorState(String name) { + restoredOperatorStateMetaInfos.remove(name); + if (registeredOperatorStates.remove(name) != null) { + accessedStatesByName.remove(name); + } + } + + public void deleteBroadCastState(String name) { + restoredBroadcastStateMetaInfos.remove(name); + if (registeredBroadcastStates.remove(name) != null) { + accessedBroadcastStatesByName.remove(name); + } + } + + public void deleteOperatorState(String name) { + restoredOperatorStateMetaInfos.remove(name); + if (registeredOperatorStates.remove(name) != null) { + accessedStatesByName.remove(name); + } + } @SuppressWarnings("unchecked") @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index 7ba14b3d007..14c2d842dcf 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -24,6 +24,8 @@ import org.apache.flink.runtime.state.heap.InternalKeyContext; import org.apache.flink.util.Disposable; +import javax.annotation.Nonnull; + import java.util.stream.Stream; /** @@ -104,6 +106,16 @@ TypeSerializer<N> namespaceSerializer, StateDescriptor<S, ?> stateDescriptor) throws Exception; + /** + * Removes the operator state with the given name from the state store, if it exists. + */ + void removeKeyedState(@Nonnull String stateName) throws Exception; + + /** + * Removes the queue state with the given name from the state store, if it exists. + */ + void removeQueueState(@Nonnull String name) throws Exception; + @Override void dispose(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java index 2245e72bce6..e41805b1ca2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java @@ -38,7 +38,7 @@ * @return the queue with the specified unique name. 
*/ @Nonnull - <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create( + <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createQueueState( @Nonnull String stateName, - @Nonnull TypeSerializer<T> byteOrderedElementSerializer); + @Nonnull TypeSerializer<T> byteOrderedElementSerializer) throws Exception; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index bc1e0f52507..e6a7a1be4d7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -62,10 +62,9 @@ import org.apache.flink.runtime.state.SnapshotResult; import org.apache.flink.runtime.state.SnapshotStrategy; import org.apache.flink.runtime.state.StateSnapshot; -import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader; import org.apache.flink.runtime.state.StateSnapshotRestore; -import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory; +import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.runtime.state.StreamCompressionDecorator; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator; @@ -185,7 +184,7 @@ public HeapKeyedStateBackend( @SuppressWarnings("unchecked") @Nonnull @Override - public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create( + public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createQueueState( @Nonnull String stateName, @Nonnull TypeSerializer<T> 
byteOrderedElementSerializer) { @@ -228,12 +227,29 @@ public HeapKeyedStateBackend( } } + @Override + public void removeQueueState(@Nonnull String stateName) { + restoredStateMetaInfo.remove(StateUID.of(stateName, StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE)); + registeredPQStates.remove(stateName); + } + + @Override + public void removeKeyedState(@Nonnull String stateName) { + restoredStateMetaInfo.remove(StateUID.of(stateName, StateMetaInfoSnapshot.BackendStateType.KEY_VALUE)); + if (registeredKVStates.remove(stateName) != null) { + if (kvStateRegistry != null) { + kvStateRegistry.unregisterKvState(stateName); + } + keyValueStatesByName.remove(stateName); + } + } + @Nonnull private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createInternal( RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) { final String stateName = metaInfo.getName(); - final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create( + final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.createQueueState( stateName, metaInfo.getElementSerializer()); @@ -312,7 +328,7 @@ private boolean hasRegisteredState() { public <N, SV, SEV, S extends State, IS extends S> IS createInternalState( @Nonnull TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<S, SV> stateDesc, - @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception { + @Nonnull StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception { StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass()); if (stateFactory == null) { String message = String.format("State %s is not supported by %s", @@ -327,7 +343,7 @@ private boolean hasRegisteredState() { @SuppressWarnings("unchecked") private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer( StateDescriptor<?, SV> stateDesc, - StateSnapshotTransformFactory<SEV> snapshotTransformFactory) 
{ + StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) { Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState(); if (original.isPresent()) { if (stateDesc instanceof ListStateDescriptor) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java index 80d79ac1fc1..b32f4deaff0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java @@ -55,7 +55,7 @@ public HeapPriorityQueueSetFactory( @Nonnull @Override - public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> HeapPriorityQueueSet<T> create( + public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> HeapPriorityQueueSet<T> createQueueState( @Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 1b2062a7481..b113e12ef69 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -53,11 +53,12 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation; import org.apache.flink.runtime.testutils.CommonTestUtils; import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; -import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import 
org.apache.flink.util.TestLogger; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -2741,7 +2742,7 @@ public void testReplicateModeStateHandle() { OperatorStateHandle osh = new OperatorStreamStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[150])); OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; - List<Collection<OperatorStateHandle>> repartitionedStates = + List<List<OperatorStateHandle>> repartitionedStates = repartitioner.repartitionState(Collections.singletonList(osh), 3); Map<String, Integer> checkCounts = new HashMap<>(3); @@ -3331,7 +3332,7 @@ private void doTestPartitionableStateRepartitioning( OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; - List<Collection<OperatorStateHandle>> pshs = + List<List<OperatorStateHandle>> pshs = repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism); Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java index d8918e78478..61ab3cd7170 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -54,6 +54,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -946,6 +947,50 @@ static MutableType of(int value) { } } + @Test + public void testDeleteBroadcastState() throws Exception { + final OperatorStateBackend operatorStateBackend = + new 
DefaultOperatorStateBackend(classLoader, new ExecutionConfig(), false); + + MapStateDescriptor<Integer, Integer> broadcastStateDesc1 = new MapStateDescriptor<>( + "test-broadcast-1", IntSerializer.INSTANCE, IntSerializer.INSTANCE); + + MapStateDescriptor<Integer, Integer> broadcastStateDesc2 = new MapStateDescriptor<>( + "test-broadcast-2", IntSerializer.INSTANCE, IntSerializer.INSTANCE); + + operatorStateBackend.getBroadcastState(broadcastStateDesc1); + operatorStateBackend.getBroadcastState(broadcastStateDesc2); + + Assert.assertEquals(2, operatorStateBackend.getRegisteredBroadcastStateNames().size()); + + operatorStateBackend.removeBroadcastState(broadcastStateDesc2.getName()); + Assert.assertEquals(1, operatorStateBackend.getRegisteredBroadcastStateNames().size()); + Assert.assertTrue(operatorStateBackend.getRegisteredBroadcastStateNames().contains(broadcastStateDesc1.getName())); + } + + @Test + public void testDeleteOperatorState() throws Exception { + final OperatorStateBackend operatorStateBackend = + new DefaultOperatorStateBackend(classLoader, new ExecutionConfig(), false); + + ListStateDescriptor<Integer> listStateDesc1 = new ListStateDescriptor<>("test-broadcast-1", IntSerializer.INSTANCE); + ListStateDescriptor<Integer> listStateDesc2 = new ListStateDescriptor<>("test-broadcast-2", IntSerializer.INSTANCE); + + operatorStateBackend.getListState(listStateDesc1); + operatorStateBackend.getUnionListState(listStateDesc2); + + Set<String> registeredStateNames = operatorStateBackend.getRegisteredStateNames(); + + Assert.assertEquals(2, registeredStateNames.size()); + + operatorStateBackend.removeOperatorState(listStateDesc1.getName()); + Assert.assertEquals(1, registeredStateNames.size()); + Assert.assertTrue(registeredStateNames.contains(listStateDesc2.getName())); + + operatorStateBackend.removeOperatorState(listStateDesc2.getName()); + Assert.assertEquals(0, registeredStateNames.size()); + } + // 
------------------------------------------------------------------------ // utilities // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index 059a706c6a8..66fec859560 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -1160,7 +1160,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception { InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE; KeyGroupedInternalPriorityQueue<InternalPriorityQueueTestBase.TestElement> priorityQueue = - keyedBackend.create(stateName, serializer); + keyedBackend.createQueueState(stateName, serializer); priorityQueue.add(new InternalPriorityQueueTestBase.TestElement(42L, 0L)); @@ -1177,7 +1177,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception { serializer = new ModifiedTestElementSerializer(); - priorityQueue = keyedBackend.create(stateName, serializer); + priorityQueue = keyedBackend.createQueueState(stateName, serializer); final InternalPriorityQueueTestBase.TestElement checkElement = new InternalPriorityQueueTestBase.TestElement(4711L, 1L); @@ -1192,7 +1192,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception { // test that the modified serializer was actually used --------------------------- keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle); - priorityQueue = keyedBackend.create(stateName, serializer); + priorityQueue = keyedBackend.createQueueState(stateName, serializer); priorityQueue.poll(); @@ -1217,7 +1217,7 @@ public void testPriorityQueueSerializerUpdates() throws Exception { try { // this is expected to fail, because the old and new serializer shoulbe be incompatible through // different 
revision numbers. - keyedBackend.create("test", serializer); + keyedBackend.createQueueState("test", serializer); Assert.fail("Expected exception from incompatible serializer."); } catch (Exception e) { Assert.assertTrue("Exception was not caused by state migration: " + e, @@ -4129,6 +4129,39 @@ public void testCheckConcurrencyProblemWhenPerformingCheckpointAsync() throws Ex } } + @Test + public void testDeleteKeyedState() throws Exception { + + Environment env = new DummyEnvironment(); + AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env); + try { + ValueStateDescriptor<Integer> kv1 = new ValueStateDescriptor<>("kv_1", IntSerializer.INSTANCE); + ValueStateDescriptor<Integer> kv2 = new ValueStateDescriptor<>("kv_2", IntSerializer.INSTANCE); + ValueState<Integer> state1 = backend.getOrCreateKeyedState(VoidNamespaceSerializer.INSTANCE, kv1); + ValueState<Integer> state2 = backend.getOrCreateKeyedState(VoidNamespaceSerializer.INSTANCE, kv2); + + backend.removeKeyedState(kv2.getName()); + } finally { + backend.dispose(); + } + } + + @Test + public void testDeletePriorityQueueState() throws Exception { + + Environment env = new DummyEnvironment(); + AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env); + try { + String state1 = "state_1"; + String state2 = "state_2"; + backend.createQueueState(state1, InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE); + backend.createQueueState(state2, InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE); + backend.removeQueueState(state2); + } finally { + backend.dispose(); + } + } + protected Future<SnapshotResult<KeyedStateHandle>> runSnapshotAsync( ExecutorService executorService, RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunnableFuture) throws Exception { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java index 0b5931ce1d4..5eb9c5f8a0d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java @@ -170,6 +170,16 @@ public void notifyCheckpointComplete(long checkpointId) { .map(Map.Entry::getKey); } + @Override + public void removeKeyedState(@Nonnull String stateName) { + + } + + @Override + public void removeQueueState(@Nonnull String name) { + + } + @Override public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot( long checkpointId, @@ -229,7 +239,7 @@ public void restore(Collection<KeyedStateHandle> state) { @Nonnull @Override public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> - create( + createQueueState( @Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { return new HeapPriorityQueueSet<>( diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index 0fd11252b2f..37864214cd5 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -85,7 +85,6 @@ import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateSnapshotTransformer; -import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory; import org.apache.flink.runtime.state.StateUtil; import 
org.apache.flink.runtime.state.StreamCompressionDecorator; import org.apache.flink.runtime.state.StreamStateHandle; @@ -353,6 +352,33 @@ private static void checkAndCreateDirectory(File directory) throws IOException { } } + @Override + public void removeQueueState(@Nonnull String stateName) throws RocksDBException { + removeInternal(stateName); + } + + @Override + public void removeKeyedState(@Nonnull String stateName) throws RocksDBException { + removeInternal(stateName); + } + + private void removeInternal(@Nonnull String name) throws RocksDBException { + restoredKvStateMetaInfos.remove(name); + Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> removedStateMetaInfo = kvStateInformation.remove(name); + if (removedStateMetaInfo != null) { + if (kvStateRegistry != null) { + kvStateRegistry.unregisterKvState(name); + } + keyValueStatesByName.remove(name); + ColumnFamilyHandle removeColumnFamily = removedStateMetaInfo.f0; + try { + db.dropColumnFamily(removeColumnFamily); + } finally { + IOUtils.closeQuietly(removeColumnFamily); + } + } + } + @SuppressWarnings("unchecked") @Override public <N> Stream<K> getKeys(String state, N namespace) { @@ -444,10 +470,10 @@ public void dispose() { @Nonnull @Override public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> - create( + createQueueState( @Nonnull String stateName, - @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { - return priorityQueueFactory.create(stateName, byteOrderedElementSerializer); + @Nonnull TypeSerializer<T> byteOrderedElementSerializer) throws Exception { + return priorityQueueFactory.createQueueState(stateName, byteOrderedElementSerializer); } private void cleanInstanceBasePath() { @@ -1381,7 +1407,7 @@ private ColumnFamilyHandle createColumnFamily(String stateName) { public <N, SV, SEV, S extends State, IS extends S> IS createInternalState( @Nonnull TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<S, SV> stateDesc, - 
@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception { + @Nonnull StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception { StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass()); if (stateFactory == null) { String message = String.format("State %s is not supported by %s", @@ -1396,7 +1422,7 @@ private ColumnFamilyHandle createColumnFamily(String stateName) { @SuppressWarnings("unchecked") private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer( StateDescriptor<?, SV> stateDesc, - StateSnapshotTransformFactory<SEV> snapshotTransformFactory) { + StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> snapshotTransformFactory) { if (stateDesc instanceof ListStateDescriptor) { Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState(); return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null); @@ -2756,7 +2782,7 @@ private static RocksIteratorWrapper getRocksIterator( @Nonnull @Override public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> - create(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { + createQueueState(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { final Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfoTuple = tryRegisterPriorityQueueMetaInfo(stateName, byteOrderedElementSerializer); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 52ba3e4f1f5..5313ed780ac 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -729,7 +729,7 @@ public void close() { public <K, N> InternalTimerService<N> getInternalTimerService( String name, TypeSerializer<N> namespaceSerializer, - Triggerable<K, N> triggerable) { + Triggerable<K, N> triggerable) throws Exception { checkTimerServiceInitialization(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java index ff48c3fae03..1939a8daccd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java @@ -82,7 +82,7 @@ public <N> InternalTimerService<N> getInternalTimerService( String name, TimerSerializer<K, N> timerSerializer, - Triggerable<K, N> triggerable) { + Triggerable<K, N> triggerable) throws Exception { InternalTimerServiceImpl<K, N> timerService = registerOrGetTimerService(name, timerSerializer); @@ -95,7 +95,10 @@ } @SuppressWarnings("unchecked") - <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService(String name, TimerSerializer<K, N> timerSerializer) { + <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService( + String name, + TimerSerializer<K, N> timerSerializer) throws Exception { + InternalTimerServiceImpl<K, N> timerService = (InternalTimerServiceImpl<K, N>) timerServices.get(name); if (timerService == null) { @@ -117,8 +120,8 @@ private <N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerPriorityQueue( String name, - TimerSerializer<K, N> timerSerializer) { - return priorityQueueSetFactory.create( + TimerSerializer<K, N> timerSerializer) throws Exception { + return priorityQueueSetFactory.createQueueState( name, timerSerializer); } diff 
--git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java index dea17f98665..ca9d9c0376c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java @@ -104,9 +104,14 @@ protected void read(DataInputView in, boolean wasVersioned) throws IOException { .getReaderForVersion(readerVersion, userCodeClassLoader) .readTimersSnapshot(in); - InternalTimerServiceImpl<K, ?> timerService = registerOrGetTimerService( - serviceName, - restoredTimersSnapshot); + InternalTimerServiceImpl<K, ?> timerService; + try { + timerService = registerOrGetTimerService( + serviceName, + restoredTimersSnapshot); + } catch (Exception e) { + throw new IOException("Could not create timer service in restore.", e); + } timerService.restoreTimersForKeyGroup(restoredTimersSnapshot, keyGroupIdx); } @@ -114,7 +119,7 @@ protected void read(DataInputView in, boolean wasVersioned) throws IOException { @SuppressWarnings("unchecked") private <N> InternalTimerServiceImpl<K, N> registerOrGetTimerService( - String serviceName, InternalTimersSnapshot<?, ?> restoredTimersSnapshot) { + String serviceName, InternalTimersSnapshot<?, ?> restoredTimersSnapshot) throws Exception { final TypeSerializer<K> keySerializer = (TypeSerializer<K>) restoredTimersSnapshot.getKeySerializer(); final TypeSerializer<N> namespaceSerializer = (TypeSerializer<N>) restoredTimersSnapshot.getNamespaceSerializer(); TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java index f2da6da3b05..c9f37e5d21a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimerServiceImplTest.java @@ -79,7 +79,7 @@ public InternalTimerServiceImplTest(int startKeyGroup, int endKeyGroup, int maxP } @Test - public void testKeyGroupStartIndexSetting() { + public void testKeyGroupStartIndexSetting() throws Exception { int startKeyGroupIdx = 7; int endKeyGroupIdx = 21; @@ -101,7 +101,7 @@ public void testKeyGroupStartIndexSetting() { } @Test - public void testTimerAssignmentToKeyGroups() { + public void testTimerAssignmentToKeyGroups() throws Exception { int totalNoOfTimers = 100; int totalNoOfKeyGroups = 100; @@ -811,7 +811,7 @@ private static int getKeyInKeyGroupRange(KeyGroupRange range, int maxParallelism KeyContext keyContext, ProcessingTimeService processingTimeService, KeyGroupRange keyGroupList, - PriorityQueueSetFactory priorityQueueSetFactory) { + PriorityQueueSetFactory priorityQueueSetFactory) throws Exception { InternalTimerServiceImpl<Integer, String> service = createInternalTimerService( keyGroupList, keyContext, @@ -892,7 +892,7 @@ protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange ProcessingTimeService processingTimeService, TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer, - PriorityQueueSetFactory priorityQueueSetFactory) { + PriorityQueueSetFactory priorityQueueSetFactory) throws Exception { TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); @@ -907,8 +907,8 @@ protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange private static <K, N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerQueue( String name, TimerSerializer<K, 
N> timerSerializer, - PriorityQueueSetFactory priorityQueueSetFactory) { - return priorityQueueSetFactory.create( + PriorityQueueSetFactory priorityQueueSetFactory) throws Exception { + return priorityQueueSetFactory.createQueueState( name, timerSerializer); } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > KafkaConsumer should use partitionable state over union state if partition > discovery is not active > -------------------------------------------------------------------------------------------------- > > Key: FLINK-10122 > URL: https://issues.apache.org/jira/browse/FLINK-10122 > Project: Flink > Issue Type: Improvement > Components: Kafka Connector > Reporter: Stefan Richter > Assignee: Stefan Richter > Priority: Major > Labels: pull-request-available > Fix For: 1.7.0 > > > KafkaConsumer always stores its offsets state as union state. I think this is > only required in the case that partition discovery is active. For jobs with a > very high parallelism, the union state can lead to prohibitively expensive > deployments. For example, a job with 2000 sources and a total of 10MB > checkpointed union state offsets state would have to ship ~ 2000 x 10MB = > 20GB of state. With partitionable state, it would have to ship ~10MB. > For now, I would suggest going back to partitionable state in the case that > partition discovery is not active. In the long run, I have some ideas for > more efficient partitioning schemes that would also work for active discovery. -- This message was sent by Atlassian JIRA (v7.6.3#76005)