Myasuka commented on code in PR #20405:
URL: https://github.com/apache/flink/pull/20405#discussion_r942346683


##########
flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/EmbeddedRocksDBStateBackendTest.java:
##########
@@ -637,11 +644,73 @@ public void testMapStateClear() throws Exception {
                             throw new RocksDBException("Artificial failure");
                         })
                 .when(keyedStateBackend.db)
-                .newIterator(any(ColumnFamilyHandle.class), any(ReadOptions.class));
+                .deleteRange(any(ColumnFamilyHandle.class), any(byte[].class), any(byte[].class));
 
         state.clear();
     }
 
+    @Test
+    public void testMapStateClearCorrectly() throws Exception {
+        verifyMapStateClear(Byte.MAX_VALUE + 1); // one byte prefix
+        verifyMapStateClear(KeyGroupRangeAssignment.UPPER_BOUND_MAX_PARALLELISM); // two byte prefix

Review Comment:
   We should separate these two tests as two unit tests.
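   A hedged sketch of what that split could look like (the class name `SplitTestsSketch` and the stubbed helper are hypothetical; the real version would be two JUnit `@Test` methods calling the PR's `verifyMapStateClear`):

```java
// Sketch of splitting the combined test into two independent tests.
// In the real suite these would be @Test methods; plain static methods and a
// stubbed helper are used here so the snippet compiles on its own.
public class SplitTestsSketch {
    static int lastVerified = -1; // records the argument the stub last received

    // Stub standing in for the PR's verifyMapStateClear(int maxParallelism).
    static void verifyMapStateClear(int maxParallelism) {
        lastVerified = maxParallelism;
    }

    // Would be: @Test public void testMapStateClearWithOneBytePrefix()
    static void testMapStateClearWithOneBytePrefix() {
        verifyMapStateClear(Byte.MAX_VALUE + 1); // 128, one-byte key-group prefix
    }

    // Would be: @Test public void testMapStateClearWithTwoBytePrefix()
    static void testMapStateClearWithTwoBytePrefix() {
        verifyMapStateClear(1 << 15); // stand-in for UPPER_BOUND_MAX_PARALLELISM
    }

    public static void main(String[] args) {
        testMapStateClearWithOneBytePrefix();
        testMapStateClearWithTwoBytePrefix();
        System.out.println("lastVerified=" + lastVerified); // prints lastVerified=32768
    }
}
```

   Splitting them also means a failure report names the exact prefix width that broke.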



##########
flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/EmbeddedRocksDBStateBackendTest.java:
##########
@@ -145,6 +149,7 @@ public static List<Object[]> modes() {
     private RocksDB db = null;
     private ColumnFamilyHandle defaultCFHandle = null;
     private RocksDBStateUploader rocksDBStateUploader = null;
+    private int maxParallelism = 2;

Review Comment:
   I think making `maxParallelism` a private field is a bit strange. We could just pass `2` to the `RocksDBTestUtils.builderForTestDB` method.



##########
flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/EmbeddedRocksDBStateBackendTest.java:
##########
@@ -637,11 +644,73 @@ public void testMapStateClear() throws Exception {
                             throw new RocksDBException("Artificial failure");
                         })
                 .when(keyedStateBackend.db)
-                .newIterator(any(ColumnFamilyHandle.class), any(ReadOptions.class));
+                .deleteRange(any(ColumnFamilyHandle.class), any(byte[].class), any(byte[].class));
 
         state.clear();
     }
 
+    @Test
+    public void testMapStateClearCorrectly() throws Exception {
+        verifyMapStateClear(Byte.MAX_VALUE + 1); // one byte prefix
+        verifyMapStateClear(KeyGroupRangeAssignment.UPPER_BOUND_MAX_PARALLELISM); // two byte prefix
+    }
+
+    public void verifyMapStateClear(int maxParallelism) throws Exception {
+        try {
+            this.maxParallelism = maxParallelism;
+            setupRocksKeyedStateBackend();
+            MapStateDescriptor<Integer, String> kvId =
+                    new MapStateDescriptor<>("id", Integer.class, String.class);
+            MapState<Integer, String> state =
+                    keyedStateBackend.getPartitionedState(
+                            VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+            keyedStateBackend = spy(keyedStateBackend);
+            doAnswer(
+                            invocationOnMock -> { // ensure that each KeyGroup can use the same Key
+                                keyedStateBackend
+                                        .getKeyContext()
+                                        .setCurrentKey(invocationOnMock.getArgument(0));
+                                keyedStateBackend
+                                        .getSharedRocksKeyBuilder()
+                                        .setKeyAndKeyGroup(
+                                                keyedStateBackend.getCurrentKey(),
+                                                keyedStateBackend.getCurrentKeyGroupIndex());
+                                return null;
+                            })
+                    .when(keyedStateBackend)
+                    .setCurrentKey(any());
+
+            Set<Integer> testKeyGroups =
+                    new HashSet<>(Arrays.asList(0, maxParallelism / 2, maxParallelism - 1));
+            for (int testKeyGroup : testKeyGroups) { // initialize
+                if (testKeyGroup >= maxParallelism - 1) {
+                    break;
+                }
+                keyedStateBackend.setCurrentKeyGroupIndex(testKeyGroup);
+                keyedStateBackend.setCurrentKey(testKeyGroup);
+                state.put(testKeyGroup, "retain " + testKeyGroup);
+                keyedStateBackend.setCurrentKey(-1); // 0xffff for the key
+                state.put(testKeyGroup, "clear " + testKeyGroup);
+            }
+
+            for (int testKeyGroup : testKeyGroups) { // test for clear
+                if (testKeyGroup >= maxParallelism - 1) {
+                    break;
+                }
+                keyedStateBackend.setCurrentKeyGroupIndex(testKeyGroup);
+                keyedStateBackend.setCurrentKey(-1);
+                assertEquals("clear " + testKeyGroup, state.get(testKeyGroup));
+                state.clear();
+                assertNull(state.get(testKeyGroup));
+                keyedStateBackend.setCurrentKey(testKeyGroup);
+                assertEquals("retain " + testKeyGroup, state.get(testKeyGroup));
+            }

Review Comment:
   I feel like this unit test is not easy to understand.
   I think we could instead put data per key group into the map state (just let different keys land in different key groups); then `state.iterator()` should return an empty iterator for cleared keys and all correct results for the remaining keys. That would be easier to understand and would really check the core logic. The only remaining question is how long these two tests would take to run; maybe we need to speed them up by using `state#get`.
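   The shape of the suggested check can be sketched framework-free (a plain nested map stands in for the RocksDB-backed `MapState`; the class and method names here are hypothetical, not Flink API): put one entry per key group, clear exactly one group, and verify the cleared group is empty while its neighbors are untouched.

```java
import java.util.HashMap;
import java.util.Map;

// Minimal sketch of the per-key-group clear check suggested in the review.
// A HashMap per key group simulates the MapState instance scoped to that group.
public class ClearPerKeyGroupSketch {
    static Map<Integer, Map<Integer, String>> statePerKeyGroup = new HashMap<>();

    static void put(int keyGroup, int key, String value) {
        statePerKeyGroup.computeIfAbsent(keyGroup, kg -> new HashMap<>()).put(key, value);
    }

    static void clear(int keyGroup) {
        // Mirrors MapState#clear for the current key group only.
        statePerKeyGroup.getOrDefault(keyGroup, new HashMap<>()).clear();
    }

    static boolean isEmpty(int keyGroup) {
        // Mirrors checking that state.iterator() yields nothing after clear.
        return statePerKeyGroup.getOrDefault(keyGroup, new HashMap<>()).isEmpty();
    }

    public static void main(String[] args) {
        int maxParallelism = 4;
        for (int kg = 0; kg < maxParallelism; kg++) {
            put(kg, kg, "value-" + kg); // different keys land in different key groups
        }
        clear(1); // clear exactly one key group's state
        assert isEmpty(1);                                // cleared group is empty
        assert !isEmpty(0) && !isEmpty(2) && !isEmpty(3); // neighbors retained
        System.out.println("ok");
    }
}
```

   The real test would replace the nested map with the RocksDB backend and use `state#get` on a handful of keys per group to keep the runtime short.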



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
