ableegoldman commented on code in PR #16201:
URL: https://github.com/apache/kafka/pull/16201#discussion_r1628531098
########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -238,24 +310,37 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
         final int crossRackTrafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost().getAsInt();
         final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost().getAsInt();

+        final Set<TaskId> standbyTasksToOptimize = new HashSet<>();
+        kafkaStreamsAssignments.values().forEach(assignment -> {
+            final Set<TaskId> standbyTasksForAssignment = assignment.tasks().values().stream()
+                .filter(task -> task.type() == AssignedTask.Type.STANDBY)
+                .map(AssignedTask::id)
+                .collect(Collectors.toSet());
+            standbyTasksToOptimize.addAll(standbyTasksForAssignment);
+        });
+
         final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId = applicationState.allTasks().values().stream().collect(Collectors.toMap(
             TaskInfo::id,
             t -> t.topicPartitions().stream().filter(TaskTopicPartition::isChangelog).collect(Collectors.toSet()))
         );

-        final List<TaskId> taskIds = new ArrayList<>(topicPartitionsByTaskId.keySet());
+        final List<TaskId> taskIds = new ArrayList<>(standbyTasksToOptimize);

Review Comment:
If the only point of the `standbyTasksToOptimize` set is to create this list, we can just do that directly with `.flatMap` (though I think we should keep the name, ie rename `taskIds` to `standbyTasksToOptimize`):
```
final List<TaskId> taskIds = kafkaStreamsAssignments.values().stream()
    .flatMap(r -> r.tasks().values().stream())
    .filter(task -> task.type() == AssignedTask.Type.STANDBY)
    .map(AssignedTask::id)
    .distinct().collect(Collectors.toList());
```

########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -92,7 +164,7 @@ public static Map<ProcessId, KafkaStreamsAssignment> identityAssignment(final Ap
      */
     public static Map<ProcessId, KafkaStreamsAssignment> defaultStandbyTaskAssignment(final ApplicationState applicationState,
                                                                                        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
-        if (applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty()) {

Review Comment:
LOL great bug 😂

########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -238,24 +306,37 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
         final int crossRackTrafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost().getAsInt();
         final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost().getAsInt();

+        final Set<TaskId> standbyTasksToOptimize = new HashSet<>();
+        kafkaStreamsAssignments.values().forEach(assignment -> {
+            final Set<TaskId> standbyTasksForAssignment = assignment.tasks().values().stream()
+                .filter(task -> task.type() == AssignedTask.Type.STANDBY)
+                .map(AssignedTask::id)
+                .collect(Collectors.toSet());
+            standbyTasksToOptimize.addAll(standbyTasksForAssignment);
+        });
+
         final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId = applicationState.allTasks().values().stream().collect(Collectors.toMap(
             TaskInfo::id,
             t -> t.topicPartitions().stream().filter(TaskTopicPartition::isChangelog).collect(Collectors.toSet()))
         );

-        final List<TaskId> taskIds = new ArrayList<>(topicPartitionsByTaskId.keySet());
+        final List<TaskId> taskIds = new ArrayList<>(standbyTasksToOptimize);
         final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates =
             applicationState.kafkaStreamsStates(false);

         final Map<UUID, Optional<String>> clientRacks = new HashMap<>();
         final List<UUID> clientIds = new ArrayList<>();
         final Map<UUID, KafkaStreamsAssignment> assignmentsByUuid = new HashMap<>();

-        for (final Map.Entry<ProcessId, KafkaStreamsAssignment> entry : kafkaStreamsAssignments.entrySet()) {
-            final UUID uuid = entry.getKey().id();
-            clientIds.add(uuid);
-            clientRacks.put(uuid, kafkaStreamsStates.get(entry.getKey()).rackId());
-            assignmentsByUuid.put(uuid, entry.getValue());
+        for (final Map.Entry<ProcessId, KafkaStreamsState> entry : kafkaStreamsStates.entrySet()) {
+            final ProcessId processId = entry.getKey();
+            clientIds.add(processId.id());
+            clientRacks.put(processId.id(), entry.getValue().rackId());
+            if (!kafkaStreamsAssignments.containsKey(processId)) {

Review Comment:
We have this exact same loop in the active task optimization method -- I assume we should make this same change up there?
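Concretely, I'd expect the active-task version of the loop to end up looking roughly like the sketch below. This is just a sketch: the variable names assume the active method uses the same `kafkaStreamsStates` / `clientIds` / `clientRacks` / `assignmentsByUuid` locals as here, and the `computeIfAbsent` body is only my guess at what the cut-off `if` block does:
```java
for (final Map.Entry<ProcessId, KafkaStreamsState> entry : kafkaStreamsStates.entrySet()) {
    final ProcessId processId = entry.getKey();
    clientIds.add(processId.id());
    clientRacks.put(processId.id(), entry.getValue().rackId());
    // a client that was handed in without an assignment still gets an (empty) entry
    final KafkaStreamsAssignment assignment = kafkaStreamsAssignments.computeIfAbsent(
        processId, id -> KafkaStreamsAssignment.of(id, new HashSet<>()));
    assignmentsByUuid.put(processId.id(), assignment);
}
```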
########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -338,8 +419,8 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
             taskMoved |= graphConstructor.assignTaskFromMinCostFlow(
                 assignmentGraph.graph,
-                clientIds,
-                taskIds,
+                clients,
+                taskIdList,

Review Comment:
oof -- this was way too easy of a bug to hit, can we rename both of these variables so it's easier to keep track of which thing is what & which thing should be passed in where?

########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentUtilsTest.java: ##########
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_4;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_5;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.assignment.ApplicationState;
+import org.apache.kafka.streams.processor.assignment.AssignmentConfigs;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsState;
+import org.apache.kafka.streams.processor.assignment.ProcessId;
+import org.apache.kafka.streams.processor.assignment.TaskAssignmentUtils;
+import org.apache.kafka.streams.processor.assignment.TaskInfo;
+import org.apache.kafka.streams.processor.assignment.TaskTopicPartition;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.junit.rules.Timeout;
+
+public class TaskAssignmentUtilsTest {
+
+    @Rule
+    public Timeout timeout = new Timeout(30, TimeUnit.SECONDS);
+
+    @ParameterizedTest
+    @ValueSource(strings = {
+        StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC,
+        StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY,
+    })
+    public void shouldOptimizeActiveTaskSimple(final String strategy) {

Review Comment:
These tests are so much more readable than the old style, nice to have all the input info and outputs so cleanly displayed 😄

########## streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsAssignmentScaleTest.java: ##########
@@ -190,11 +197,35 @@ private void completeLargeAssignment(final int numPartitions,
         configMap.put(InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, taskAssignor.getName());
         configMap.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, numStandbys);
-        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(
+        configMap.put(StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_CONFIG,
+            StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE);

Review Comment:
I forget, can we run it with the other rack-aware strategy?
Or did both of them take too long for this scale test?

########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -613,9 +704,17 @@ private static Map<ProcessId, KafkaStreamsAssignment> loadBasedStandbyTaskAssign
             .collect(Collectors.toSet());
         final Map<TaskId, Integer> tasksToRemainingStandbys = statefulTaskIds.stream()
             .collect(Collectors.toMap(Function.identity(), t -> numStandbyReplicas));
-        final Map<UUID, KafkaStreamsAssignment> clients = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap(
+        final Map<UUID, KafkaStreamsAssignment> clients = streamStates.entrySet().stream().collect(Collectors.toMap(
             entry -> entry.getKey().id(),
-            Map.Entry::getValue
+            entry -> {
+                final KafkaStreamsState state = entry.getValue();
+                if (kafkaStreamsAssignments.containsKey(state.processId())) {
+                    return kafkaStreamsAssignments.get(state.processId());
+                }
+                final KafkaStreamsAssignment newAssignment = KafkaStreamsAssignment.of(state.processId(), new HashSet<>());
+                kafkaStreamsAssignments.put(state.processId(), newAssignment);
+                return newAssignment;
+            }

Review Comment:
IIUC this is essentially the same "bug" as above, ie accounting for an invalid assignment being passed in where not all ProcessIds in the KafkaStreamsState map have a corresponding KafkaStreamsAssignment, right? It feels a little "sneaky" to just fill in missing assignments while constructing this map.

Also, it seems like we'd want to do the same thing -- initialize any missing ProcessIds with an empty KafkaStreamsAssignment -- for each of the methods here, ie #defaultStandbyTaskAssignment and the two rack-aware optimization methods for active & standby tasks. So maybe just split out this logic into a separate helper function that each method can call at the initialization period, before the actual algorithm begins. That way we don't need to worry about this ever again.

Also: once we do the UUID->ProcessID thing, we'll want to get rid of this particular variable altogether, since its only purpose in the beginning was to convert from a Map<ProcessId, KafkaStreamsAssignment> to a Map<UUID, KafkaStreamsAssignment>. We don't want to add more logic here specifically.
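To make that concrete, something along these lines is roughly what I'm picturing -- the name `initializeMissingAssignments` and the exact signature are just placeholders, not a prescription:
```java
/**
 * Sketch of a shared initialization helper: every ProcessId known from the client states
 * ends up with a (possibly empty) KafkaStreamsAssignment before any algorithm runs.
 */
private static void initializeMissingAssignments(final Map<ProcessId, KafkaStreamsState> streamStates,
                                                 final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
    for (final ProcessId processId : streamStates.keySet()) {
        // clients that were handed in without an assignment get an empty one up front
        kafkaStreamsAssignments.computeIfAbsent(processId, id -> KafkaStreamsAssignment.of(id, new HashSet<>()));
    }
}
```
Then #defaultStandbyTaskAssignment and both rack-aware optimization methods can just call it once up front and stop special-casing missing ProcessIds.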
########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentUtilsTest.java: ##########
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_4;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_5;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.assignment.ApplicationState;
+import org.apache.kafka.streams.processor.assignment.AssignmentConfigs;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
+import org.apache.kafka.streams.processor.assignment.KafkaStreamsState;
+import org.apache.kafka.streams.processor.assignment.ProcessId;
+import org.apache.kafka.streams.processor.assignment.TaskAssignmentUtils;
+import org.apache.kafka.streams.processor.assignment.TaskInfo;
+import org.apache.kafka.streams.processor.assignment.TaskTopicPartition;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.junit.rules.Timeout;
+
+public class TaskAssignmentUtilsTest {
+
+    @Rule
+    public Timeout timeout = new Timeout(30, TimeUnit.SECONDS);
+
+    @ParameterizedTest
+    @ValueSource(strings = {
+        StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC,
+        StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY,
+    })
+    public void shouldOptimizeActiveTaskSimple(final String strategy) {
+        final AssignmentConfigs assignmentConfigs = defaultAssignmentConfigs(
+            strategy, 100, 1, 1, Collections.emptyList());
+        final Map<TaskId, TaskInfo> tasks = mkMap(
+            mkTaskInfo(TASK_0_0, true, mkSet("rack-2")),
+            mkTaskInfo(TASK_0_1, true, mkSet("rack-1"))
+        );
+        final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = mkMap(
+            mkStreamState(1, 1, Optional.of("rack-1")),
+            mkStreamState(2, 1, Optional.of("rack-2"))
+        );
+        final ApplicationState applicationState = new TestApplicationState(
+            assignmentConfigs, kafkaStreamsStates, tasks);
+
+        final Map<ProcessId, KafkaStreamsAssignment> assignments = mkMap(
+            mkAssignment(AssignedTask.Type.ACTIVE, 1, TASK_0_0),
+            mkAssignment(AssignedTask.Type.ACTIVE, 2, TASK_0_1)
+        );
+
+        TaskAssignmentUtils.optimizeRackAwareActiveTasks(
+            applicationState,
+            assignments, new TreeSet<>(tasks.keySet()));
+        assertThat(assignments.size(), equalTo(2));
+        assertThat(assignments.get(processId(1)).tasks().keySet(), equalTo(mkSet(TASK_0_1)));
+        assertThat(assignments.get(processId(2)).tasks().keySet(), equalTo(mkSet(TASK_0_0)));
+
+        TaskAssignmentUtils.optimizeRackAwareActiveTasks(
+            applicationState, assignments, new TreeSet<>(tasks.keySet()));
+        assertThat(assignments.size(), equalTo(2));
+        assertThat(assignments.get(processId(1)).tasks().keySet(), equalTo(mkSet(TASK_0_1)));
+        assertThat(assignments.get(processId(2)).tasks().keySet(), equalTo(mkSet(TASK_0_0)));

Review Comment:
Did you mean to repeat this? If it's intentional that's fine (though maybe leave a comment as to why) but if not then please remove the duplicate code
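If it is intentional, even a short comment would help the next reader -- e.g., assuming (and this is just my guess) that the point of the second call is to show the optimization is a no-op on an already-optimized assignment:
```java
// Running the optimization again on an already-optimized assignment should be a no-op,
// so the exact same assertions must still hold.
TaskAssignmentUtils.optimizeRackAwareActiveTasks(
    applicationState, assignments, new TreeSet<>(tasks.keySet()));
```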
########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ##########
@@ -92,7 +164,7 @@ public static Map<ProcessId, KafkaStreamsAssignment> identityAssignment(final Ap
      */
     public static Map<ProcessId, KafkaStreamsAssignment> defaultStandbyTaskAssignment(final ApplicationState applicationState,
                                                                                        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
-        if (applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty()) {

Review Comment:
LOL whoops 😂