showuon commented on a change in pull request #9888:
URL: https://github.com/apache/kafka/pull/9888#discussion_r557447421



##########
File path: streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java
##########
@@ -88,69 +89,96 @@ public void setup() {
                 mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class)
             )
         );
+
+        createAndRunStream();
     }
 
     @After
     public void teardown() throws IOException {
+        if (kafkaStreams != null) {
+            kafkaStreams.close();
+        }
         purgeLocalStreamsState(properties);
     }
 
+    private void addStreamStateChangeListener(final KafkaStreams kafkaStreams) {
+        // we store each new state in state transition so that we won't miss any state change
+        kafkaStreams.setStateListener(
+            (newState, oldState) -> stateToTransitions.add(newState)
+        );
+    }
+
+    private void waitForStateTransition(final KafkaStreams.State expected) throws InterruptedException {
+        waitForCondition(
+            () -> !stateToTransitions.isEmpty() && stateToTransitions.contains(expected),
+            DEFAULT_DURATION.toMillis(),
+            () -> String.format("Client did not change to the %s state in time. Observed new state transitions: %s",
+                expected, stateToTransitions)
+        );
+    }
+
+    private void createAndRunStream() throws InterruptedException {
+        kafkaStreams = new KafkaStreams(builder.build(), properties);
+        addStreamStateChangeListener(kafkaStreams);
+        kafkaStreams.start();
+        waitForStateTransition(KafkaStreams.State.RUNNING);
+    }
+
     @Test
     public void shouldAddStreamThread() throws Exception {
-        try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) {
-            StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams);
-            final int oldThreadCount = kafkaStreams.localThreadsMetadata().size();
-            assertThat(kafkaStreams.localThreadsMetadata().stream().map(t -> t.threadName().split("-StreamThread-")[1]).sorted().toArray(), equalTo(new String[] {"1", "2"}));
-
-            final Optional<String> name = kafkaStreams.addStreamThread();
-
-            assertThat(name, not(Optional.empty()));
-            TestUtils.waitForCondition(
-                () -> kafkaStreams.localThreadsMetadata().stream().sequential()
-                        .map(ThreadMetadata::threadName).anyMatch(t -> t.equals(name.orElse(""))),
-                "Wait for the thread to be added"
-            );
-            assertThat(kafkaStreams.localThreadsMetadata().size(), equalTo(oldThreadCount + 1));
-            assertThat(
-                kafkaStreams
-                    .localThreadsMetadata()
-                    .stream()
-                    .map(t -> t.threadName().split("-StreamThread-")[1])
-                    .sorted().toArray(),
-                equalTo(new String[] {"1", "2", "3"})
-            );
-            waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.REBALANCING, DEFAULT_DURATION);
-            waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.RUNNING, DEFAULT_DURATION);
-        }
+        final int oldThreadCount = kafkaStreams.localThreadsMetadata().size();
+        assertThat(kafkaStreams.localThreadsMetadata().stream().map(t -> t.threadName().split("-StreamThread-")[1]).sorted().toArray(), equalTo(new String[] {"1", "2"}));
+
+        stateToTransitions.clear();
+        final Optional<String> name = kafkaStreams.addStreamThread();
+
+        assertThat(name, not(Optional.empty()));
+        TestUtils.waitForCondition(
+            () -> kafkaStreams.localThreadsMetadata().stream().sequential()
+                    .map(ThreadMetadata::threadName).anyMatch(t -> t.equals(name.orElse(""))),
+            "Wait for the thread to be added"
+        );
+        assertThat(kafkaStreams.localThreadsMetadata().size(), equalTo(oldThreadCount + 1));
+        assertThat(
+            kafkaStreams
+                .localThreadsMetadata()
+                .stream()
+                .map(t -> t.threadName().split("-StreamThread-")[1])
+                .sorted().toArray(),
+            equalTo(new String[] {"1", "2", "3"})
+        );
+
+        waitForStateTransition(KafkaStreams.State.REBALANCING);
+        waitForStateTransition(KafkaStreams.State.RUNNING);

Review comment:
       Good suggestion! Added a `hasStateTransition` method to verify that. Thanks.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to