Hi,

I recently learned that timers can be set in the KeyedStateFunction that is
passed to KeyedBroadcastProcessFunction.Context#applyToKeyedState. The
"trick" is to keep a reference to the TimerService returned by
ctx.timerService() in processElement and use it inside that function.

I have not seen this behavior explicitly documented. The closest I could
find is:

> Context in the processBroadcastElement() method contains the method
> applyToKeyedState(StateDescriptor<S, VS> stateDescriptor,
> KeyedStateFunction<KS, S> function). This allows to register a
> KeyedStateFunction to be applied to all states of all keys associated
> with the provided stateDescriptor.


from
https://nightlies.apache.org/flink/flink-docs-release-1.19/docs/dev/datastream/fault-tolerance/broadcast_state/#broadcastprocessfunction-and-keyedbroadcastprocessfunction


A developer would have to infer that *all states* also covers timers.
However, the section immediately after that says:

> Registering timers is only possible at processElement() of the
> KeyedBroadcastProcessFunction and only there. It is not possible in the
> processBroadcastElement() method, as there is no key associated to the
> broadcasted elements.


which made me think that setting timers *anywhere* in
processBroadcastElement, including in the user-supplied KeyedStateFunction,
is not possible, even though it is.


*Should the documentation be updated to mention that timers can be set from
applyToKeyedState?*
I did notice that the new DataStream v2 API makes this behavior much
clearer via the new ApplyPartitionFunction#apply(Collector<OUT> collector,
PartitionedContext ctx) method, which is great!

Below is an example of registering timers from processBroadcastElement in
Flink 1.17.

import lombok.NoArgsConstructor;
import lombok.Setter;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.KeyedStateFunction;
import org.apache.flink.streaming.api.TimerService;
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
import org.apache.flink.util.Collector;

/**
 * Function that shows that timers can be set in processBroadcastElement via
 * ctx.applyToKeyedState.
 */
public class ApplyToKeyedStateTimerExample
    extends KeyedBroadcastProcessFunction<String, String, Long, String> {
  private static final long serialVersionUID = -8682056773621520927L;
  private transient ValueState<String> state;
  // Transient and created in open(): the function is serialized when the job
  // is built, and neither CachedTimerService nor TimerService is Serializable.
  private transient CachedTimerService cachedTimerService;

  @Override
  public void processElement(String value, ReadOnlyContext ctx, Collector<String> out)
      throws Exception {
    state.update(value);
    // Keep a reference to the keyed TimerService for use on the broadcast side.
    cachedTimerService.setTimerService(ctx.timerService());
  }

  @Override
  public void processBroadcastElement(Long value, Context ctx, Collector<String> out)
      throws Exception {
    // applyToKeyedState runs the function once per key that has a value for
    // "state", with that key set as the current key, so each invocation
    // registers an event-time timer for that key.
    ctx.applyToKeyedState(
        new ValueStateDescriptor<>("state", Types.STRING),
        cachedTimerService.registerTimer(value));
  }

  @Override
  public void onTimer(final long timestamp, final OnTimerContext ctx, final Collector<String> out)
      throws Exception {
    // Emit and clear the value stored for the timer's key.
    var value = state.value();
    if (value != null) {
      out.collect(value);
      state.clear();
    }
  }

  @Override
  public void open(Configuration parameters) throws Exception {
    state = getRuntimeContext().getState(new ValueStateDescriptor<>("state", Types.STRING));
    cachedTimerService = new CachedTimerService();
  }

  @NoArgsConstructor
  private static final class CachedTimerService {
    @Setter private TimerService timerService;

    public <K, S extends State> KeyedStateFunction<K, S> registerTimer(long timestamp) {
      return (key, state) -> {
        // Null until processElement has run on this parallel instance (for
        // example right after a restore); then no timer is registered.
        if (timerService != null) {
          timerService.registerEventTimeTimer(timestamp);
        }
      };
    }
  }
}

Here is a unit test for the above function.

import java.util.stream.Collectors;

import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.ProcessFunctionTestHarnesses;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class ApplyToKeyedStateTimerExampleTest {
  @Test
  public void testTimersFromBroadcast() throws Exception {
    var harness =
        ProcessFunctionTestHarnesses.forKeyedBroadcastProcessFunction(
            new ApplyToKeyedStateTimerExample(),
            x -> x,
            Types.STRING,
            new MapStateDescriptor<>("foo", Types.STRING, Types.LONG));
    harness.open();
    harness.processElement("foo", 1);
    harness.processElement("bar", 1);
    harness.processBroadcastElement(2L, 2);
    harness.watermark(2);
    var output =
        harness.getOutput().stream()
            .filter(x -> x instanceof StreamRecord) // Filter out the watermark StreamElement
            .collect(Collectors.toList());
    Assertions.assertEquals(2, output.size());
  }
}
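Why does the cached TimerService register a timer for the right key? As far as I can tell from the Flink source, the timer service does not capture a key; it reads the keyed backend's current key when registerEventTimeTimer is called, and applyToKeyedState sets the current key before invoking the KeyedStateFunction for each key that has a value for the given descriptor. Below is a self-contained toy model of that interaction in plain Java, with no Flink dependency. ToyApplyToKeyedState, ToyBackend and ToyTimerService are hypothetical stand-ins, not Flink classes, and the model is a sketch under the assumption just described. It also shows a caveat: only keys with state are visited, so a broadcast element that arrives before any keyed element registers nothing.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * Toy model (hypothetical classes, NOT Flink's API) of why a TimerService cached in
 * processElement registers per-key timers when it is used inside applyToKeyedState.
 */
public class ToyApplyToKeyedState {

  /** Hypothetical stand-in for the keyed state backend: one value state plus a current key. */
  static final class ToyBackend {
    final Map<String, String> valueState = new LinkedHashMap<>();
    String currentKey;

    /** Assumed to mirror applyToKeyedState: set the current key, then call the function. */
    void applyToAllKeys(BiConsumer<String, String> function) {
      for (Map.Entry<String, String> entry : valueState.entrySet()) {
        currentKey = entry.getKey();
        function.accept(entry.getKey(), entry.getValue());
      }
    }
  }

  /** Hypothetical stand-in for TimerService: the key is read at call time, never captured. */
  static final class ToyTimerService {
    private final ToyBackend backend;
    final List<String> timers = new ArrayList<>();

    ToyTimerService(ToyBackend backend) {
      this.backend = backend;
    }

    void registerEventTimeTimer(long timestamp) {
      timers.add(backend.currentKey + "@" + timestamp);
    }
  }

  /** Same sequence as the unit test: elements "foo" and "bar", then broadcast 2L. */
  static List<String> simulate() {
    ToyBackend backend = new ToyBackend();
    ToyTimerService timerService = new ToyTimerService(backend);
    ToyTimerService cached = null;

    for (String value : List.of("foo", "bar")) {
      backend.currentKey = value; // key selector x -> x
      backend.valueState.put(value, value); // state.update(value)
      cached = timerService; // cache the service handed to processElement
    }

    // Broadcast side: no key of its own (modelled as null), but the cached
    // service picks up whichever key applyToAllKeys has just set.
    ToyTimerService service = cached;
    backend.currentKey = null;
    backend.applyToAllKeys((key, state) -> service.registerEventTimeTimer(2L));
    return timerService.timers;
  }

  /** Broadcast before any keyed element: no key has state yet, so nothing is visited. */
  static List<String> simulateBroadcastFirst() {
    ToyBackend backend = new ToyBackend();
    ToyTimerService timerService = new ToyTimerService(backend);
    backend.applyToAllKeys((key, state) -> timerService.registerEventTimeTimer(2L));
    return timerService.timers;
  }

  public static void main(String[] args) {
    System.out.println(simulate()); // prints [foo@2, bar@2]
    System.out.println(simulateBroadcastFirst()); // prints []
  }
}
```

Running main prints one timer per key, which matches the two records the unit test above expects once the watermark reaches 2.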


Thanks,
Jose
