Github user sihuazhou commented on a diff in the pull request: https://github.com/apache/flink/pull/6196#discussion_r197331898 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java --- @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state.ttl; + +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.CompositeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyedStateFactory; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * This state factory wraps state objects, produced by backends, with TTL logic. + */ +public class TtlStateFactory { + public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc, + KeyedStateFactory originalStateFactory, + TtlConfig ttlConfig, + TtlTimeProvider timeProvider) throws Exception { + return ttlConfig.getTtlUpdateType() == TtlUpdateType.Disabled ? + originalStateFactory.createState(namespaceSerializer, stateDesc) : + new TtlStateFactory(originalStateFactory, ttlConfig, timeProvider) + .createState(namespaceSerializer, stateDesc); + } + + private final Map<Class<? 
extends StateDescriptor>, StateFactory> stateFactories; + + private final KeyedStateFactory originalStateFactory; + private final TtlConfig ttlConfig; + private final TtlTimeProvider timeProvider; + + private TtlStateFactory(KeyedStateFactory originalStateFactory, TtlConfig ttlConfig, TtlTimeProvider timeProvider) { + this.originalStateFactory = originalStateFactory; + this.ttlConfig = ttlConfig; + this.timeProvider = timeProvider; + this.stateFactories = createStateFactories(); + } + + private Map<Class<? extends StateDescriptor>, StateFactory> createStateFactories() { + return Stream.of( + Tuple2.of(ValueStateDescriptor.class, (StateFactory) this::createValueState), + Tuple2.of(ListStateDescriptor.class, (StateFactory) this::createListState), + Tuple2.of(MapStateDescriptor.class, (StateFactory) this::createMapState), + Tuple2.of(ReducingStateDescriptor.class, (StateFactory) this::createReducingState), + Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) this::createAggregatingState), + Tuple2.of(FoldingStateDescriptor.class, (StateFactory) this::createFoldingState) + ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); + } + + private interface StateFactory { + <N, SV, S extends State, IS extends S> IS create( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception; + } + + private <N, SV, S extends State, IS extends S> IS createState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + StateFactory stateFactory = stateFactories.get(stateDesc.getClass()); + if (stateFactory == null) { + String message = String.format("State %s is not supported by %s", + stateDesc.getClass(), TtlStateFactory.class); + throw new FlinkRuntimeException(message); + } + return stateFactory.create(namespaceSerializer, stateDesc); + } + + @SuppressWarnings("unchecked") + private <N, SV, S extends State, IS extends S> IS createValueState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, 
SV> stateDesc) throws Exception { + SV defVal = stateDesc.getDefaultValue(); + TtlValue<SV> ttlDefVal = defVal == null ? null : new TtlValue<>(defVal, Long.MAX_VALUE); + ValueStateDescriptor<TtlValue<SV>> ttlDescriptor = new ValueStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()), ttlDefVal); + return (IS) new TtlValueState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <T, N, SV, S extends State, IS extends S> IS createListState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + ListStateDescriptor<T> listStateDesc = (ListStateDescriptor<T>) stateDesc; + ListStateDescriptor<TtlValue<T>> ttlDescriptor = new ListStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer())); + return (IS) new TtlListState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, listStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <UK, UV, N, SV, S extends State, IS extends S> IS createMapState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + MapStateDescriptor<UK, UV> mapStateDesc = (MapStateDescriptor<UK, UV>) stateDesc; + MapStateDescriptor<UK, TtlValue<UV>> ttlDescriptor = new MapStateDescriptor<>( + stateDesc.getName(), + mapStateDesc.getKeySerializer(), + new TtlSerializer<>(mapStateDesc.getValueSerializer())); + return (IS) new TtlMapState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, mapStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <N, SV, S extends State, IS extends S> IS createReducingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + ReducingStateDescriptor<SV> 
reducingStateDesc = (ReducingStateDescriptor<SV>) stateDesc; + ReducingStateDescriptor<TtlValue<SV>> ttlDescriptor = new ReducingStateDescriptor<>( + stateDesc.getName(), + new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlReducingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <IN, OUT, N, SV, S extends State, IS extends S> IS createAggregatingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + AggregatingStateDescriptor<IN, SV, OUT> aggregatingStateDescriptor = + (AggregatingStateDescriptor<IN, SV, OUT>) stateDesc; + TtlAggregateFunction<IN, SV, OUT> ttlAggregateFunction = new TtlAggregateFunction<>( + aggregatingStateDescriptor.getAggregateFunction(), ttlConfig, timeProvider); + AggregatingStateDescriptor<IN, TtlValue<SV>, OUT> ttlDescriptor = new AggregatingStateDescriptor<>( + stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlAggregatingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction); + } + + @SuppressWarnings("unchecked") + private <T, N, SV, S extends State, IS extends S> IS createFoldingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + FoldingStateDescriptor<T, SV> foldingStateDescriptor = (FoldingStateDescriptor<T, SV>) stateDesc; + SV initAcc = stateDesc.getDefaultValue(); + TtlValue<SV> ttlInitAcc = initAcc == null ? 
null : new TtlValue<>(initAcc, Long.MAX_VALUE); + FoldingStateDescriptor<T, TtlValue<SV>> ttlDescriptor = new FoldingStateDescriptor<>( + stateDesc.getName(), + ttlInitAcc, + new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlFoldingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>> { + TtlSerializer(TypeSerializer<T> userValueSerializer) { --- End diff -- It seems `userValueSerializer` can't be null?
---