xishuaidelin commented on code in PR #25717:
URL: https://github.com/apache/flink/pull/25717#discussion_r1909730165


##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/async/AbstractAsyncStateTopNFunction.java:
##########
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.rank.async;
+
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
+import org.apache.flink.table.runtime.operators.rank.AbstractTopNFunction;
+import org.apache.flink.table.runtime.operators.rank.RankRange;
+import org.apache.flink.table.runtime.operators.rank.RankType;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+
+/**
+ * Base class for TopN Function with async state api.
+ *
+ * <p>TODO FLINK-36831 support variable rank end in async state rank later.
+ */
+public abstract class AbstractAsyncStateTopNFunction extends 
AbstractTopNFunction {
+
+    public AbstractAsyncStateTopNFunction(
+            StateTtlConfig ttlConfig,
+            InternalTypeInfo<RowData> inputRowType,
+            GeneratedRecordComparator generatedSortKeyComparator,
+            RowDataKeySelector sortKeySelector,
+            RankType rankType,
+            RankRange rankRange,
+            boolean generateUpdateBefore,
+            boolean outputRankNumber) {
+        super(
+                ttlConfig,
+                inputRowType,
+                generatedSortKeyComparator,
+                sortKeySelector,
+                rankType,
+                rankRange,
+                generateUpdateBefore,
+                outputRankNumber);
+        if (!isConstantRankEnd) {
+            throw new UnsupportedOperationException(
+                    "Variable rank end is not supported in rank with async 
state api.");

Review Comment:
   Nit: how about hinting in the error message at the version in which this will be supported?



##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/AbstractTopNFunction.java:
##########
@@ -326,4 +294,69 @@ private RowData createOutputRow(RowData inputRow, long 
rank, RowKind rowKind) {
     public void setKeyContext(KeyContext keyContext) {
         this.keyContext = keyContext;
     }
+
+    /** An abstract helper to do the logic Top-n used for all top-n functions. 
*/
+    public abstract static class AbstractTopNHelper {
+
+        protected final AbstractTopNFunction topNFunction;
+
+        protected final StateTtlConfig ttlConfig;
+
+        protected final KeySelector<RowData, RowData> sortKeySelector;
+
+        protected final Comparator<RowData> sortKeyComparator;
+
+        protected final boolean outputRankNumber;
+
+        protected final KeyContext keyContext;
+
+        // metrics

Review Comment:
   Nit: could this comment be expanded to explain what these metrics measure?



##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/FastTop1Function.java:
##########
@@ -112,66 +89,58 @@ public void open(OpenContext openContext) throws Exception 
{
         }
         dataState = getRuntimeContext().getState(valueStateDescriptor);
 
+        helper = new SyncStateFastTop1Helper();
+
         // metrics

Review Comment:
   Ditto



##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/async/AsyncStateFastTop1Function.java:
##########
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.rank.async;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.serialization.SerializerConfigImpl;
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.api.common.state.v2.StateFuture;
+import org.apache.flink.api.common.state.v2.ValueState;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.state.StateFutureUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import 
org.apache.flink.streaming.runtime.operators.asyncprocessing.AsyncStateProcessingOperator;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
+import org.apache.flink.table.runtime.operators.rank.AppendOnlyTopNFunction;
+import org.apache.flink.table.runtime.operators.rank.FastTop1Function;
+import org.apache.flink.table.runtime.operators.rank.RankRange;
+import org.apache.flink.table.runtime.operators.rank.RankType;
+import org.apache.flink.table.runtime.operators.rank.UpdatableTopNFunction;
+import org.apache.flink.table.runtime.operators.rank.utils.FastTop1Helper;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.util.Collector;
+
+/**
+ * A more concise implementation for {@link AppendOnlyTopNFunction} and {@link
+ * UpdatableTopNFunction} when only Top-1 is desired. This function can handle 
updating stream
+ * because the RankProcessStrategy is inferred as UpdateFastStrategy, i.e., 1) 
the upsert key of
+ * input steam contains partition key; 2) the sort field is updated monotonely 
under the upsert key.
+ *
+ * <p>Different with {@link FastTop1Function}, this function is used with 
async state api.
+ */
+public class AsyncStateFastTop1Function extends AbstractAsyncStateTopNFunction
+        implements CheckpointedFunction {
+
+    private static final long serialVersionUID = 1L;
+
+    private final TypeSerializer<RowData> inputRowSer;
+    private final long cacheSize;
+
+    // a value state stores the latest record
+    private transient ValueState<RowData> dataState;
+
+    private transient AsyncStateFastTop1Helper helper;
+
+    public AsyncStateFastTop1Function(
+            StateTtlConfig ttlConfig,
+            InternalTypeInfo<RowData> inputRowType,
+            GeneratedRecordComparator generatedSortKeyComparator,
+            RowDataKeySelector sortKeySelector,
+            RankType rankType,
+            RankRange rankRange,
+            boolean generateUpdateBefore,
+            boolean outputRankNumber,
+            long cacheSize) {
+        super(
+                ttlConfig,
+                inputRowType,
+                generatedSortKeyComparator,
+                sortKeySelector,
+                rankType,
+                rankRange,
+                generateUpdateBefore,
+                outputRankNumber);
+
+        this.inputRowSer = inputRowType.createSerializer(new 
SerializerConfigImpl());
+        this.cacheSize = cacheSize;
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        ValueStateDescriptor<RowData> valueStateDescriptor =
+                new ValueStateDescriptor<>("Top1-Rank-State", inputRowType);
+        if (ttlConfig.isEnabled()) {
+            valueStateDescriptor.enableTimeToLive(ttlConfig);
+        }
+        dataState =
+                ((StreamingRuntimeContext) 
getRuntimeContext()).getValueState(valueStateDescriptor);
+
+        helper = new AsyncStateFastTop1Helper();
+
+        // metrics

Review Comment:
   Nit: this comment appears to be redundant because the method name is already 
self-explanatory.



##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/FastTop1Function.java:
##########
@@ -112,66 +89,58 @@ public void open(OpenContext openContext) throws Exception 
{
         }
         dataState = getRuntimeContext().getState(valueStateDescriptor);
 
+        helper = new SyncStateFastTop1Helper();
+
         // metrics
-        registerMetric(kvCache.size() * getDefaultTopNSize());
+        helper.registerMetric();
     }
 
     @Override
     public void processElement(RowData input, Context ctx, Collector<RowData> 
out)
             throws Exception {
-        requestCount += 1;
+        helper.accRequestCount();
+
         // load state under current key if necessary
         RowData currentKey = (RowData) keyContext.getCurrentKey();
-        RowData prevRow = kvCache.getIfPresent(currentKey);
+        RowData prevRow = helper.getPrevRowFromCache(currentKey);
         if (prevRow == null) {
             prevRow = dataState.value();
         } else {
-            hitCount += 1;
+            helper.accHitCount();
         }
 
         // first row under current key.
         if (prevRow == null) {
-            kvCache.put(currentKey, inputRowSer.copy(input));
-            if (outputRankNumber) {
-                collectInsert(out, input, 1);
-            } else {
-                collectInsert(out, input);
-            }
-            return;
-        }
-
-        RowData curSortKey = sortKeySelector.getKey(input);
-        RowData oldSortKey = sortKeySelector.getKey(prevRow);
-        int compare = sortKeyComparator.compare(curSortKey, oldSortKey);
-        // current sort key is higher than old sort key
-        if (compare < 0) {
-            kvCache.put(currentKey, inputRowSer.copy(input));
-            // Note: partition key is unique key if only top-1 is desired,
-            //  thus emitting UB and UA here
-            if (outputRankNumber) {
-                collectUpdateBefore(out, prevRow, 1);
-                collectUpdateAfter(out, input, 1);
-            } else {
-                collectUpdateBefore(out, prevRow);
-                collectUpdateAfter(out, input);
-            }
+            helper.processAsFirstRow(input, currentKey, out);
+        } else {
+            helper.processWithPrevRow(input, currentKey, prevRow, out);
         }
     }
 
     @Override
     public void snapshotState(FunctionSnapshotContext context) throws 
Exception {
-        for (Map.Entry<RowData, RowData> entry : kvCache.asMap().entrySet()) {
-            keyContext.setCurrentKey(entry.getKey());
-            flushBufferToState(entry.getValue());
-        }
+        helper.flushAllCacheToState();
     }
 
     @Override
     public void initializeState(FunctionInitializationContext context) throws 
Exception {
         // nothing to do
     }
 
-    private void flushBufferToState(RowData rowData) throws Exception {
-        dataState.update(rowData);
+    private class SyncStateFastTop1Helper extends FastTop1Helper {
+
+        public SyncStateFastTop1Helper() {
+            super(
+                    FastTop1Function.this,
+                    inputRowSer,
+                    cacheSize,
+                    FastTop1Function.this.getDefaultTopNSize());
+        }
+
+        @Override
+        public void flushBufferToState(RowData currentKey, RowData value) 
throws Exception {
+            keyContext.setCurrentKey(currentKey);
+            FastTop1Function.this.dataState.update(value);

Review Comment:
   This implementation breaks the encapsulation of the class by directly 
accessing its private field. Would it be possible to add a new method to 
FastTop1Function to handle this logic instead?



##########
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/rank/AbstractSyncStateTopNFunction.java:
##########
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.rank;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+
+/** Base class for TopN Function with sync state api. */
+public abstract class AbstractSyncStateTopNFunction extends 
AbstractTopNFunction {
+
+    private ValueState<Long> rankEndState;
+
+    public AbstractSyncStateTopNFunction(
+            StateTtlConfig ttlConfig,
+            InternalTypeInfo<RowData> inputRowType,
+            GeneratedRecordComparator generatedSortKeyComparator,
+            RowDataKeySelector sortKeySelector,
+            RankType rankType,
+            RankRange rankRange,
+            boolean generateUpdateBefore,
+            boolean outputRankNumber) {
+        super(
+                ttlConfig,
+                inputRowType,
+                generatedSortKeyComparator,
+                sortKeySelector,
+                rankType,
+                rankRange,
+                generateUpdateBefore,
+                outputRankNumber);
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        if (!isConstantRankEnd) {
+            ValueStateDescriptor<Long> rankStateDesc =
+                    new ValueStateDescriptor<>("rankEnd", Types.LONG);
+            if (ttlConfig.isEnabled()) {
+                rankStateDesc.enableTimeToLive(ttlConfig);
+            }
+            rankEndState = getRuntimeContext().getState(rankStateDesc);
+        }
+    }
+
+    /**
+     * Initialize rank end.
+     *
+     * @param row input record
+     * @return rank end
+     * @throws Exception
+     */
+    protected long initRankEnd(RowData row) throws Exception {

Review Comment:
   The return value of this function is never used.



##########
flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/RankHarnessTest.scala:
##########
@@ -451,4 +467,128 @@ class RankHarnessTest(mode: StateBackendMode) extends 
HarnessTestBase(mode) {
     assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, 
result)
     testHarness.close()
   }
+
+  def prepareTop1Tester(query: String, operatorNameIdentifier: String)
+      : (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData], 
RowDataHarnessAssertor) = {
+    val sourceDDL =
+      s"""
+         |CREATE TEMPORARY TABLE T(
+         |  a STRING PRIMARY KEY NOT ENFORCED,
+         |  b BIGINT
+         |) WITH (
+         |  'connector' = 'values',
+         |  'changelog-mode' = 'I'
+         |)
+         |""".stripMargin
+    tEnv.executeSql(sourceDDL)
+
+    val t1 = tEnv.sqlQuery(query)
+
+    val testHarness =
+      createHarnessTester(t1.toRetractStream[Row], operatorNameIdentifier)
+    val assertor = new RowDataHarnessAssertor(
+      Array(
+        DataTypes.STRING().getLogicalType,
+        DataTypes.BIGINT().getLogicalType,
+        DataTypes.BIGINT().getLogicalType))
+
+    (testHarness, assertor)
+  }
+
+  @TestTemplate
+  def testAppendFastTop1(): Unit = {
+    tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
+    val query =
+      """
+        |SELECT a, b, rn
+        |FROM
+        |(
+        |    SELECT a, b,
+        |        ROW_NUMBER() OVER (PARTITION BY a ORDER BY b DESC) AS rn
+        |    FROM T
+        |) t1
+        |WHERE rn <= 1
+      """.stripMargin
+    val (testHarness, assertor) =
+      prepareTop1Tester(query, "Rank(strategy=[AppendFastStrategy")
+
+    if (enableAsyncState) {
+      assertThat(isAsyncStateOperator(testHarness)).isTrue
+    } else {
+      assertThat(isAsyncStateOperator(testHarness)).isFalse
+    }
+
+    testHarness.open()
+
+    testHarness.processElement(binaryRecord(INSERT, "a", 2L: JLong))
+    testHarness.processElement(binaryRecord(INSERT, "a", 1L: JLong))
+    testHarness.processElement(binaryRecord(INSERT, "a", 3L: JLong))
+
+    val result = dropWatermarks(testHarness.getOutput.toArray)
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+    expectedOutput.add(binaryRecord(INSERT, "a", 2L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 2L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 3L: JLong))
+
+    assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, 
result)
+
+    testHarness.close()
+  }
+
+  @TestTemplate
+  def testUpdateFastTop1(): Unit = {
+    tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
+    val query =
+      """
+        |SELECT a, b, rn
+        |FROM
+        |(
+        |    SELECT a, b,
+        |        ROW_NUMBER() OVER (PARTITION BY a ORDER BY b DESC) AS rn
+        |    FROM (
+        |       select a, count(*) as b from T group by a
+        |    ) t1
+        |) t2
+        |WHERE rn <= 1
+      """.stripMargin
+    val (testHarness, assertor) =
+      prepareTop1Tester(query, "Rank(strategy=[UpdateFastStrategy")
+
+    if (enableAsyncState) {
+      assertThat(isAsyncStateOperator(testHarness)).isTrue
+    } else {
+      assertThat(isAsyncStateOperator(testHarness)).isFalse
+    }
+
+    testHarness.open()
+
+    testHarness.processElement(binaryRecord(INSERT, "a", 2L: JLong))
+    testHarness.processElement(binaryRecord(UPDATE_AFTER, "a", 3L: JLong))
+    testHarness.processElement(binaryRecord(UPDATE_AFTER, "a", 4L: JLong))
+
+    val result = dropWatermarks(testHarness.getOutput.toArray)
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+    expectedOutput.add(binaryRecord(INSERT, "a", 2L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 2L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 3L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 3L: JLong))
+    expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 4L: JLong))
+

Review Comment:
   This test has left me a bit confused. According to the SQL query, we are 
calculating the maximum number of rows for each distinct value of a. How do 
values greater than 1 end up in the expected output?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to