davidradl commented on code in PR #26424: URL: https://github.com/apache/flink/pull/26424#discussion_r2035544454
########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/over/NonTimeRowsUnboundedPrecedingFunction.java: ########## @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.over; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.metrics.Counter; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.utils.JoinedRowData; 
+import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore; +import org.apache.flink.table.runtime.generated.AggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedRecordComparator; +import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.generated.RecordEqualiser; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.types.RowKind; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; + +/** + * The NonTimeRowsUnboundedPrecedingFunction class is a specialized implementation for processing + * unbounded OVER window aggregations, particularly for non-time-based rows queries in Apache Flink. + * It maintains strict ordering of rows within partitions and handles the full changelog lifecycle + * (inserts, updates, deletes). 
+ * + * <p>Key Components and Assumptions + * + * <p>Data Structure Design: (1) Maintains a sorted list of tuples containing sort keys and lists of + * IDs for each key (2) Each incoming row is assigned a unique Long ID (starting from + * Long.MIN_VALUE) (3) Uses multiple state types to track rows, sort orders, and aggregations + * + * <p>State Management: (1) idState: Counter for generating unique row IDs (2) sortedListState: + * Ordered list of sort keys with their associated row IDs (3) valueMapState: Maps IDs to their + * corresponding input rows (4) accMapState: Maps IDs to their accumulated values + * + * <p>Processing Model: (1) For inserts/updates: Adds rows to the appropriate position based on sort + * key (2) For deletes: Removes rows by matching both sort key and row content (3) Recalculates + * aggregates for affected rows and emits the appropriate events (4) Skips redundant events when + * accumulators haven't changed to reduce network traffic + * + * <p>Optimization Assumptions: (1) Skip emitting updates when accumulators haven't changed to + * reduce network traffic (2) Uses state TTL for automatic cleanup of stale data (3) Carefully + * manages row state to support incremental calculations + * + * <p>Retraction Handling: (1) Handles retraction mode (DELETE/UPDATE_BEFORE) events properly (2) + * Supports the proper processing of changelog streams + * + * <p>Limitations + * + * <p>Linear search performance: - The current implementation uses a linear search to find the + * correct position for each sort key. This can be optimized using a binary search for large state + * sizes. 
+ * + * <p>State size and performance: - The implementation maintains multiple state types that could + * grow large with high cardinality data + * + * <p>Linear recalculation: - When processing updates, all subsequent elements need to be + * recalculated, which could be inefficient for large windows + */ +public class NonTimeRowsUnboundedPrecedingFunction<K> + extends KeyedProcessFunction<K, RowData, RowData> { + private static final long serialVersionUID = 1L; + + private static final Logger LOG = + LoggerFactory.getLogger(NonTimeRowsUnboundedPrecedingFunction.class); + + private final long stateRetentionTime; + + private final GeneratedAggsHandleFunction generatedAggsHandler; + private final GeneratedRecordEqualiser generatedRecordEqualiser; + private final GeneratedRecordEqualiser generatedSortKeyEqualiser; + private final GeneratedRecordComparator generatedSortKeyComparator; + + // The util to compare two rows based on the sort attribute. + private transient Comparator<RowData> sortKeyComparator; + + protected final KeySelector<RowData, RowData> sortKeySelector; + // The record equaliser used to equal RowData. 
+ private transient RecordEqualiser valueEqualiser; + private transient RecordEqualiser sortKeyEqualiser; + + private final LogicalType[] accTypes; + private final LogicalType[] inputFieldTypes; + private final LogicalType[] sortKeyTypes; + protected transient JoinedRowData output; + + // state to hold the Long ID counter + private transient ValueState<Long> idState; + @VisibleForTesting protected transient ValueStateDescriptor<Long> idStateDescriptor; + + // state to hold a sorted list each containing a tuple of sort key and list of IDs + private transient ValueState<List<Tuple2<RowData, List<Long>>>> sortedListState; + + @VisibleForTesting + protected transient ValueStateDescriptor<List<Tuple2<RowData, List<Long>>>> + sortedListStateDescriptor; + + // state to hold ID and its associated input row until state ttl expires + private transient MapState<Long, RowData> valueMapState; + @VisibleForTesting protected transient MapStateDescriptor<Long, RowData> valueStateDescriptor; + // state to hold ID and its associated accumulator + private transient MapState<RowData, RowData> accMapState; + @VisibleForTesting protected transient MapStateDescriptor<RowData, RowData> accStateDescriptor; + + protected transient AggsHandleFunction aggFuncs; + + // Metrics + private static final String IDS_NOT_FOUND_METRIC_NAME = "numOfIdsNotFound"; + private transient Counter numOfIdsNotFound; + private static final String SORT_KEYS_NOT_FOUND_METRIC_NAME = "numOfSortKeysNotFound"; + private transient Counter numOfSortKeysNotFound; + + @VisibleForTesting + protected Counter getNumOfIdsNotFound() { + return numOfIdsNotFound; + } + + @VisibleForTesting + protected Counter getNumOfSortKeysNotFound() { + return numOfSortKeysNotFound; + } + + public NonTimeRowsUnboundedPrecedingFunction( + long stateRetentionTime, + GeneratedAggsHandleFunction genAggsHandler, + GeneratedRecordEqualiser genRecordEqualiser, + GeneratedRecordEqualiser genSortKeyEqualiser, + GeneratedRecordComparator 
genSortKeyComparator, + LogicalType[] accTypes, + LogicalType[] inputFieldTypes, + LogicalType[] sortKeyTypes, + RowDataKeySelector sortKeySelector) { + this.stateRetentionTime = stateRetentionTime; + this.generatedAggsHandler = genAggsHandler; + this.generatedRecordEqualiser = genRecordEqualiser; + this.generatedSortKeyEqualiser = genSortKeyEqualiser; + this.generatedSortKeyComparator = genSortKeyComparator; + this.accTypes = accTypes; + this.inputFieldTypes = inputFieldTypes; + this.sortKeyTypes = sortKeyTypes; + this.sortKeySelector = sortKeySelector; + } + + @Override + public void open(OpenContext openContext) throws Exception { + // Initialize agg functions + aggFuncs = generatedAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader()); + aggFuncs.open(new PerKeyStateDataViewStore(getRuntimeContext())); + + // Initialize output record + output = new JoinedRowData(); + + // Initialize value/row equaliser + valueEqualiser = + generatedRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + + // Initialize sortKey equaliser + sortKeyEqualiser = + generatedSortKeyEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader()); + + // Initialize sort comparator + sortKeyComparator = + generatedSortKeyComparator.newInstance( + getRuntimeContext().getUserCodeClassLoader()); + + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); + + // Initialize state to maintain id counter + idStateDescriptor = new ValueStateDescriptor<Long>("idState", Long.class); + if (ttlConfig.isEnabled()) { + idStateDescriptor.enableTimeToLive(ttlConfig); + } + idState = getRuntimeContext().getState(idStateDescriptor); + + // Input elements are all binary rows as they came from network + InternalTypeInfo<RowData> inputRowTypeInfo = InternalTypeInfo.ofFields(inputFieldTypes); + InternalTypeInfo<RowData> sortKeyRowTypeInfo = InternalTypeInfo.ofFields(this.sortKeyTypes); + InternalTypeInfo<RowData> idRowTypeInfo = InternalTypeInfo.ofFields(new 
BigIntType()); + + // Initialize state which maintains a sorted list of tuples(sortKey, List of IDs) + ListTypeInfo<Long> idListTypeInfo = new ListTypeInfo<Long>(Types.LONG); + ListTypeInfo<Tuple2<RowData, List<Long>>> listTypeInfo = + new ListTypeInfo<>(new TupleTypeInfo<>(sortKeyRowTypeInfo, idListTypeInfo)); + sortedListStateDescriptor = + new ValueStateDescriptor<List<Tuple2<RowData, List<Long>>>>( + "sortedListState", listTypeInfo); + if (ttlConfig.isEnabled()) { + sortedListStateDescriptor.enableTimeToLive(ttlConfig); + } + sortedListState = getRuntimeContext().getState(sortedListStateDescriptor); + + // Initialize state which maintains the actual row + valueStateDescriptor = + new MapStateDescriptor<Long, RowData>( + "valueMapState", Types.LONG, inputRowTypeInfo); + if (ttlConfig.isEnabled()) { + valueStateDescriptor.enableTimeToLive(ttlConfig); + } + valueMapState = getRuntimeContext().getMapState(valueStateDescriptor); + + // Initialize accumulator state per row + InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes); + accStateDescriptor = + new MapStateDescriptor<RowData, RowData>("accMapState", idRowTypeInfo, accTypeInfo); + if (ttlConfig.isEnabled()) { + accStateDescriptor.enableTimeToLive(ttlConfig); + } + accMapState = getRuntimeContext().getMapState(accStateDescriptor); + + // metrics + this.numOfIdsNotFound = + getRuntimeContext().getMetricGroup().counter(IDS_NOT_FOUND_METRIC_NAME); + this.numOfSortKeysNotFound = + getRuntimeContext().getMetricGroup().counter(SORT_KEYS_NOT_FOUND_METRIC_NAME); + } + + /** + * Puts an element from the input stream into state or removes it from state if the input is a + * retraction. Emits the aggregated value for the newly inserted element and updates all results + * that are affected by the added or removed row. Emits different aggregated value for all + * elements with the same sortKey to comply with the sql ROWS syntax. + * + * @param input The input value. 
+ * @param ctx A {@link Context} that allows querying the timestamp of the element and getting + * TimerService for registering timers and querying the time. The context is only valid + * during the invocation of this method, do not store it. + * @param out The collector for returning result values. + * @throws Exception + */ + @Override + public void processElement( + RowData input, + KeyedProcessFunction<K, RowData, RowData>.Context ctx, + Collector<RowData> out) + throws Exception { + RowKind rowKind = input.getRowKind(); + + switch (rowKind) { + case INSERT: + case UPDATE_AFTER: + insertIntoSortedList(input, out); + break; + + case DELETE: + case UPDATE_BEFORE: + removeFromSortedList(input, out); + break; + } + + // Reset acc state since we can have out of order inserts into the ordered list + aggFuncs.resetAccumulators(); + aggFuncs.cleanup(); + } + + /** + * Adds a new element(insRow) to a sortedList. The sortedList contains a list of Tuple(SortKey, + * List[Long Ids]). Extracts the inputSortKey from the insRow and compares it with every + * element in the sortedList. If a sortKey already exists in the sortedList for the input, add + * the id to the list of ids and update the sortedList, otherwise find the right position in the + * sortedList and add a new entry in the sortedList. After the insRow is successfully inserted, + * an INSERT/UPDATE_AFTER event is emitted for the newly inserted element, and for all + * subsequent elements an UPDATE_BEFORE and UPDATE_AFTER event is emitted based on the previous + * and newly aggregated values. Some updates are skipped if the previously accumulated value is + * the same as the newly accumulated value to save on network bandwidth and downstream + * processing including writing the result to the sink system. + * + * @param insRow The input value. + * @param out The collector for returning result values. 
+ * @throws Exception + */ + private void insertIntoSortedList(RowData insRow, Collector<RowData> out) throws Exception { + Long id = getNextId(); + List<Tuple2<RowData, List<Long>>> sortedList = getSortedList(); + RowKind origRowKind = insRow.getRowKind(); + insRow.setRowKind(RowKind.INSERT); + RowData inputSortKey = sortKeySelector.getKey(insRow); + Tuple2<Integer, Boolean> indexForInsertOrUpdate = + findIndexOfSortKey(sortedList, inputSortKey, false); + boolean isInsert = indexForInsertOrUpdate.f1; + int index = indexForInsertOrUpdate.f0; + if (isInsert) { + if (index == -1) { + // Insert at the end of the sortedList + sortedList.add(new Tuple2<>(inputSortKey, List.of(id))); + index = sortedList.size() - 1; + } else { + // Insert at position i of the sortedList + sortedList.add(index, new Tuple2<>(inputSortKey, List.of(id))); + } + setAccumulatorOfPrevId(sortedList, index - 1, -1); + aggFuncs.accumulate(insRow); + collectInsertOrUpdateAfter(out, insRow, origRowKind, aggFuncs.getValue()); + } else { + // Update at position i + List<Long> ids = new ArrayList<>(sortedList.get(index).f1); + ids.add(id); + sortedList.set(index, new Tuple2<>(inputSortKey, ids)); + setAccumulatorOfPrevId(sortedList, index, ids.size() - 2); + reAccumulateIdsAfterInsert(aggFuncs.getAccumulators(), ids, insRow); + emitUpdatesForIds( + ids, + ids.size() - 1, + accMapState.get(GenericRowData.of(ids.get(ids.size() - 2))), // prevAcc + aggFuncs.getAccumulators(), // currAcc + origRowKind, + insRow, + out); + } + + // Add/Update state + valueMapState.put(id, insRow); + accMapState.put(GenericRowData.of(id), aggFuncs.getAccumulators()); + sortedListState.update(sortedList); + idState.update(++id); + + processRemainingElements(sortedList, index + 1, aggFuncs.getAccumulators(), out); + } + + /** + * @return the next id after reading from the idState + * @throws IOException + */ + private Long getNextId() throws IOException { + Long id = idState.value(); + if (id == null) { + id = Long.MIN_VALUE; 
+ } + return id; + } + + /** + * @return the sortedList containing sortKeys and Ids with the same sortKey + * @throws IOException + */ + private List<Tuple2<RowData, List<Long>>> getSortedList() throws IOException { + List<Tuple2<RowData, List<Long>>> sortedList = sortedListState.value(); + if (sortedList == null) { + sortedList = new ArrayList<>(); + } + return sortedList; + } + + /** + * Returns the position of the index where the inputRow must be inserted, updated or deleted. + * For insertion, if a suitable position is not found, return -1 to be inserted at the end of Review Comment: And what does -1 mean for an update? nit: remove the word `matching` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org