Github user sihuazhou commented on a diff in the pull request: https://github.com/apache/flink/pull/6173#discussion_r195903433 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java --- @@ -203,91 +216,16 @@ private boolean hasRegisteredState() { } @Override - public <N, V> InternalValueState<K, N, V> createValueState( - TypeSerializer<N> namespaceSerializer, - ValueStateDescriptor<V> stateDesc) throws Exception { - - StateTable<K, N, V> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapValueState<>( - stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue()); - } - - @Override - public <N, T> InternalListState<K, N, T> createListState( - TypeSerializer<N> namespaceSerializer, - ListStateDescriptor<T> stateDesc) throws Exception { - - StateTable<K, N, List<T>> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapListState<>( - stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue()); - } - - @Override - public <N, T> InternalReducingState<K, N, T> createReducingState( - TypeSerializer<N> namespaceSerializer, - ReducingStateDescriptor<T> stateDesc) throws Exception { - - StateTable<K, N, T> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapReducingState<>( - stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue(), - stateDesc.getReduceFunction()); - } - - @Override - public <N, T, ACC, R> InternalAggregatingState<K, N, T, ACC, R> createAggregatingState( - TypeSerializer<N> namespaceSerializer, - AggregatingStateDescriptor<T, ACC, R> stateDesc) throws Exception { - - StateTable<K, N, ACC> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapAggregatingState<>( - 
stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue(), - stateDesc.getAggregateFunction()); - } - - @Override - public <N, T, ACC> InternalFoldingState<K, N, T, ACC> createFoldingState( - TypeSerializer<N> namespaceSerializer, - FoldingStateDescriptor<T, ACC> stateDesc) throws Exception { - - StateTable<K, N, ACC> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapFoldingState<>( - stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue(), - stateDesc.getFoldFunction()); - } - - @Override - protected <N, UK, UV> InternalMapState<K, N, UK, UV> createMapState( - TypeSerializer<N> namespaceSerializer, - MapStateDescriptor<UK, UV> stateDesc) throws Exception { - - StateTable<K, N, Map<UK, UV>> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - - return new HeapMapState<>( - stateTable, - keySerializer, - stateTable.getStateSerializer(), - stateTable.getNamespaceSerializer(), - stateDesc.getDefaultValue()); + public <N, SV, S extends State, IS extends S> IS createState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + if (!STATE_FACTORIES.containsKey(stateDesc.getClass())) { + String message = String.format("State %s is not supported by %s", + stateDesc.getClass(), this.getClass()); + throw new FlinkRuntimeException(message); + } + StateTable<K, N, SV> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); + return STATE_FACTORIES.get(stateDesc.getClass()).createState(stateDesc, stateTable, keySerializer); --- End diff -- Same as above: the `containsKey()` and `get()` calls could be merged into a single `get()` followed by a null check, which avoids looking up the map twice.
---