[ https://issues.apache.org/jira/browse/FLINK-9513?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16519968#comment-16519968 ]
ASF GitHub Bot commented on FLINK-9513: --------------------------------------- Github user sihuazhou commented on a diff in the pull request: https://github.com/apache/flink/pull/6196#discussion_r197331741 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java --- @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state.ttl; + +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.CompositeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyedStateFactory; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * This state factory wraps state objects, produced by backends, with TTL logic. + */ +public class TtlStateFactory { + public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc, + KeyedStateFactory originalStateFactory, + TtlConfig ttlConfig, + TtlTimeProvider timeProvider) throws Exception { + return ttlConfig.getTtlUpdateType() == TtlUpdateType.Disabled ? + originalStateFactory.createState(namespaceSerializer, stateDesc) : + new TtlStateFactory(originalStateFactory, ttlConfig, timeProvider) + .createState(namespaceSerializer, stateDesc); + } + + private final Map<Class<? 
extends StateDescriptor>, StateFactory> stateFactories; + + private final KeyedStateFactory originalStateFactory; + private final TtlConfig ttlConfig; + private final TtlTimeProvider timeProvider; + + private TtlStateFactory(KeyedStateFactory originalStateFactory, TtlConfig ttlConfig, TtlTimeProvider timeProvider) { + this.originalStateFactory = originalStateFactory; + this.ttlConfig = ttlConfig; + this.timeProvider = timeProvider; + this.stateFactories = createStateFactories(); + } + + private Map<Class<? extends StateDescriptor>, StateFactory> createStateFactories() { + return Stream.of( + Tuple2.of(ValueStateDescriptor.class, (StateFactory) this::createValueState), + Tuple2.of(ListStateDescriptor.class, (StateFactory) this::createListState), + Tuple2.of(MapStateDescriptor.class, (StateFactory) this::createMapState), + Tuple2.of(ReducingStateDescriptor.class, (StateFactory) this::createReducingState), + Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) this::createAggregatingState), + Tuple2.of(FoldingStateDescriptor.class, (StateFactory) this::createFoldingState) + ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); + } + + private interface StateFactory { + <N, SV, S extends State, IS extends S> IS create( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception; + } + + private <N, SV, S extends State, IS extends S> IS createState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + StateFactory stateFactory = stateFactories.get(stateDesc.getClass()); + if (stateFactory == null) { + String message = String.format("State %s is not supported by %s", + stateDesc.getClass(), TtlStateFactory.class); + throw new FlinkRuntimeException(message); + } + return stateFactory.create(namespaceSerializer, stateDesc); + } + + @SuppressWarnings("unchecked") + private <N, SV, S extends State, IS extends S> IS createValueState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, 
SV> stateDesc) throws Exception { + SV defVal = stateDesc.getDefaultValue(); + TtlValue<SV> ttlDefVal = defVal == null ? null : new TtlValue<>(defVal, Long.MAX_VALUE); + ValueStateDescriptor<TtlValue<SV>> ttlDescriptor = new ValueStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()), ttlDefVal); + return (IS) new TtlValueState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <T, N, SV, S extends State, IS extends S> IS createListState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + ListStateDescriptor<T> listStateDesc = (ListStateDescriptor<T>) stateDesc; + ListStateDescriptor<TtlValue<T>> ttlDescriptor = new ListStateDescriptor<>( + stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer())); + return (IS) new TtlListState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, listStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <UK, UV, N, SV, S extends State, IS extends S> IS createMapState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + MapStateDescriptor<UK, UV> mapStateDesc = (MapStateDescriptor<UK, UV>) stateDesc; + MapStateDescriptor<UK, TtlValue<UV>> ttlDescriptor = new MapStateDescriptor<>( + stateDesc.getName(), + mapStateDesc.getKeySerializer(), + new TtlSerializer<>(mapStateDesc.getValueSerializer())); + return (IS) new TtlMapState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, mapStateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <N, SV, S extends State, IS extends S> IS createReducingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + ReducingStateDescriptor<SV> 
reducingStateDesc = (ReducingStateDescriptor<SV>) stateDesc; + ReducingStateDescriptor<TtlValue<SV>> ttlDescriptor = new ReducingStateDescriptor<>( + stateDesc.getName(), + new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlReducingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + @SuppressWarnings("unchecked") + private <IN, OUT, N, SV, S extends State, IS extends S> IS createAggregatingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + AggregatingStateDescriptor<IN, SV, OUT> aggregatingStateDescriptor = + (AggregatingStateDescriptor<IN, SV, OUT>) stateDesc; + TtlAggregateFunction<IN, SV, OUT> ttlAggregateFunction = new TtlAggregateFunction<>( + aggregatingStateDescriptor.getAggregateFunction(), ttlConfig, timeProvider); + AggregatingStateDescriptor<IN, TtlValue<SV>, OUT> ttlDescriptor = new AggregatingStateDescriptor<>( + stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlAggregatingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction); + } + + @SuppressWarnings("unchecked") + private <T, N, SV, S extends State, IS extends S> IS createFoldingState( + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, SV> stateDesc) throws Exception { + FoldingStateDescriptor<T, SV> foldingStateDescriptor = (FoldingStateDescriptor<T, SV>) stateDesc; + SV initAcc = stateDesc.getDefaultValue(); + TtlValue<SV> ttlInitAcc = initAcc == null ? 
null : new TtlValue<>(initAcc, Long.MAX_VALUE); + FoldingStateDescriptor<T, TtlValue<SV>> ttlDescriptor = new FoldingStateDescriptor<>( + stateDesc.getName(), + ttlInitAcc, + new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider), + new TtlSerializer<>(stateDesc.getSerializer())); + return (IS) new TtlFoldingState<>( + originalStateFactory.createState(namespaceSerializer, ttlDescriptor), + ttlConfig, timeProvider, stateDesc.getSerializer()); + } + + private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>> { + TtlSerializer(TypeSerializer<T> userValueSerializer) { + super(Arrays.asList(userValueSerializer, new LongSerializer())); + } + + @Override + @SuppressWarnings("unchecked") + protected TtlValue<T> composeValue(List values) { --- End diff -- Should we check that `values != null && values.size() == 2 `? > Wrap state binder with TTL logic > -------------------------------- > > Key: FLINK-9513 > URL: https://issues.apache.org/jira/browse/FLINK-9513 > Project: Flink > Issue Type: Sub-task > Components: State Backends, Checkpointing > Affects Versions: 1.6.0 > Reporter: Andrey Zagrebin > Assignee: Andrey Zagrebin > Priority: Major > Labels: pull-request-available > Fix For: 1.6.0 > > > The main idea is to wrap user state value with a class holding the value and > the expiration timestamp (maybe meta data in future) and use the new object > as a value in the existing implementations: > {code:java} > class TtlValue<V> { > V value; > long expirationTimestamp; > } > {code} > The original state binder factory is wrapped with TtlStateBinder if TTL is > enabled: > {code:java} > state = ttlConfig.updateType == DISABLED ? 
> bind(binder) : bind(new TtlStateBinder(binder, timerService)); > {code} > TtlStateBinder decorates the states produced by the original binder with TTL > logic wrappers and adds TtlValue serialisation logic: > {code:java} > TtlStateBinder { > StateBinder binder; > ProcessingTimeProvider timeProvider; // System.currentTimeMillis() > <V> TtlValueState<V> createValueState(valueDesc) { > serializer = new TtlValueSerializer(valueDesc.getSerializer); > ttlValueDesc = new ValueDesc(serializer, ...); > // or implement custom TypeInfo > originalStateWithTtl = binder.createValueState(ttlValueDesc); > return new TtlValueState(originalStateWithTtl, timeProvider); > } > // List, Map, ... > } > {code} > TTL serializer should add expiration timestamp. -- This message was sent by Atlassian JIRA (v7.6.3#76005)