dmvk commented on code in PR #21981: URL: https://github.com/apache/flink/pull/21981#discussion_r1116631711
########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/JobSchedulingPlan.java: ########## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptive; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobmaster.SlotInfo; +import org.apache.flink.runtime.scheduler.adaptive.allocator.VertexParallelism; + +import java.util.Collection; + +/** + * A plan that describes how to execute {@link org.apache.flink.runtime.jobgraph.JobGraph JobGraph}. + * + * <ol> + * <li>{@link #vertexParallelism} is necessary to create {@link + * org.apache.flink.runtime.executiongraph.ExecutionGraph ExecutionGraph} + * <li>{@link #slotAssignments} are used to schedule it onto the cluster + * </ol> + * + * {@link AdaptiveScheduler} passes this structure from {@link WaitingForResources} to {@link + * CreatingExecutionGraph} stages. 
+ */ +@Internal +public class JobSchedulingPlan { + private final VertexParallelism vertexParallelism; + private final Collection<SlotAssignment> slotAssignments; + + public JobSchedulingPlan( + VertexParallelism vertexParallelism, Collection<SlotAssignment> slotAssignments) { + this.vertexParallelism = vertexParallelism; + this.slotAssignments = slotAssignments; + } + + public VertexParallelism getVertexParallelism() { + return vertexParallelism; + } + + public Collection<SlotAssignment> getSlotAssignments() { + return slotAssignments; + } + + /** Assignment of a slot to some target (e.g. a slot sharing group). */ + public static class SlotAssignment { + private final SlotInfo slotInfo; + /** + * Interpreted by {@link + * org.apache.flink.runtime.scheduler.adaptive.allocator.SlotAllocator#tryReserveResources(JobSchedulingPlan)}. + * This can be a slot sharing group, a task, or something else. + */ + private final Object target; Review Comment: In the current code this is always `ExecutionSlotSharingGroup` ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/SlotSharingSlotAllocator.java: ########## @@ -93,20 +93,15 @@ private static Map<SlotSharingGroupId, Integer> getMaxParallelismForSlotSharingG } @Override - public Optional<VertexParallelismWithSlotSharing> determineParallelism( + public Optional<? extends VertexParallelism> determineParallelism( Review Comment: Since there are no classes extending from `VertexParallelism`, would it make sense to change this? ```suggestion public Optional<VertexParallelism> determineParallelism( ``` ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateLocalitySlotAssigner.java: ########## @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.clusterframework.types.AllocationID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; +import org.apache.flink.runtime.jobmaster.SlotInfo; +import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment; +import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; + +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; +import static org.apache.flink.runtime.scheduler.adaptive.allocator.DefaultSlotAssigner.createExecutionSlotSharingGroups; + +/** A {@link SlotAssigner} that assigns slots based on the number of local key groups. 
*/ +@Internal +public class StateLocalitySlotAssigner implements SlotAssigner { + + private static class AllocationScore implements Comparable<AllocationScore> { + + private final String group; + private final AllocationID allocationId; + + public AllocationScore(String group, AllocationID allocationId, long score) { + this.group = group; + this.allocationId = allocationId; + this.score = score; + } + + private final long score; + + public String getGroup() { + return group; + } + + public AllocationID getAllocationId() { + return allocationId; + } + + public long getScore() { + return score; + } + + @Override + public int compareTo(StateLocalitySlotAssigner.AllocationScore other) { + int result = Long.compare(score, other.score); + if (result != 0) { + return result; + } + result = other.allocationId.compareTo(allocationId); + if (result != 0) { + return result; + } + return other.group.compareTo(group); + } + } + + @Override + public Collection<SlotAssignment> assignSlots( + JobInformation jobInformation, + Collection<? 
extends SlotInfo> freeSlots, + VertexParallelism vertexParallelism, + AllocationsInfo previousAllocations, + StateSizeEstimates stateSizeEstimates) { + final List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>(); + for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) { + allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup)); + } + final Map<JobVertexID, Integer> parallelism = getParallelism(allGroups); + + // PQ orders the pairs (allocationID, groupID) by score, decreasing + // the score is computed as the potential amount of state that would reside locally + final PriorityQueue<AllocationScore> scores = + new PriorityQueue<>(Comparator.reverseOrder()); + for (ExecutionSlotSharingGroup group : allGroups) { + calculateScore( + group, + parallelism, + jobInformation, + previousAllocations, + stateSizeEstimates) + .entrySet().stream() + .map(e -> new AllocationScore(group.getId(), e.getKey(), e.getValue())) + .forEach(scores::add); + } + + Map<String, ExecutionSlotSharingGroup> groupsById = + allGroups.stream().collect(toMap(ExecutionSlotSharingGroup::getId, identity())); + Map<AllocationID, SlotInfo> slotsById = + freeSlots.stream().collect(toMap(SlotInfo::getAllocationId, identity())); + AllocationScore score; + final Collection<SlotAssignment> assignments = new ArrayList<>(); + while ((score = scores.poll()) != null) { + SlotInfo slot = slotsById.remove(score.getAllocationId()); + if (slot != null) { + ExecutionSlotSharingGroup group = groupsById.remove(score.getGroup()); + if (group != null) { Review Comment: If the group happens to be null, does it mean we're wasting a slot (because it has already been removed from the `slotsById` map)? ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java: ########## @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.KeyedStateHandle; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; + +/** Managed Keyed State size estimates used to make scheduling decisions. 
*/ +class StateSizeEstimates { + private final Map<JobVertexID, Long> averages; + + public StateSizeEstimates() { + this(Collections.emptyMap()); + } + + public StateSizeEstimates(Map<JobVertexID, Long> averages) { + this.averages = averages; + } + + public Optional<Long> estimate(JobVertexID jobVertexId) { + return Optional.ofNullable(averages.get(jobVertexId)); + } + + static StateSizeEstimates empty() { + return new StateSizeEstimates(); + } + + static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) { + return Optional.ofNullable(executionGraph) + .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator())) + .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore())) + .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint())) + .map( + cp -> + build( + fromCompletedCheckpoint(cp), + mapVerticesToOperators(executionGraph))) + .orElse(empty()); + } + + private static StateSizeEstimates build( + Map<OperatorID, Long> sizePerOperator, + Map<JobVertexID, Set<OperatorID>> verticesToOperators) { + Map<JobVertexID, Long> verticesToSizes = + verticesToOperators.entrySet().stream() + .collect( + toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator))); + return new StateSizeEstimates(verticesToSizes); + } + + private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) { + return ids.stream() + .mapToLong(key -> sizes.getOrDefault(key, 0L)) + .boxed() + .reduce(Long::sum) + .orElse(0L); + } + + private static Map<JobVertexID, Set<OperatorID>> mapVerticesToOperators( + ExecutionGraph executionGraph) { + return executionGraph.getAllVertices().entrySet().stream() + .collect(toMap(Map.Entry::getKey, e -> getOperatorIDS(e.getValue()))); + } + + private static Set<OperatorID> getOperatorIDS(ExecutionJobVertex v) { + return v.getOperatorIDs().stream() + .map(OperatorIDPair::getGeneratedOperatorID) + .collect(Collectors.toSet()); + } + + private static Map<OperatorID, Long> 
fromCompletedCheckpoint(CompletedCheckpoint cp) { + Stream<Map.Entry<OperatorID, OperatorState>> states = + cp.getOperatorStates().entrySet().stream(); + Map<OperatorID, Long> estimates = + states.collect( + toMap(Map.Entry::getKey, e -> estimateKeyGroupStateSize(e.getValue()))); + return estimates; + } + + private static long estimateKeyGroupStateSize(OperatorState state) { Review Comment: ```suggestion private static long calculateAverageKeyGroupStateSizeInBytes(OperatorState state) { ``` ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java: ########## @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.KeyedStateHandle; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; + +/** Managed Keyed State size estimates used to make scheduling decisions. */ +class StateSizeEstimates { Review Comment: Can we have some tests for this class? ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java: ########## @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.KeyedStateHandle; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; + +/** Managed Keyed State size estimates used to make scheduling decisions. */ +class StateSizeEstimates { + private final Map<JobVertexID, Long> averages; + + public StateSizeEstimates() { + this(Collections.emptyMap()); + } + + public StateSizeEstimates(Map<JobVertexID, Long> averages) { + this.averages = averages; + } + + public Optional<Long> estimate(JobVertexID jobVertexId) { + return Optional.ofNullable(averages.get(jobVertexId)); + } + + static StateSizeEstimates empty() { + return new StateSizeEstimates(); + } + + static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) { + return Optional.ofNullable(executionGraph) + .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator())) + .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore())) + .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint())) + .map( + cp -> + build( + fromCompletedCheckpoint(cp), + mapVerticesToOperators(executionGraph))) + .orElse(empty()); + } + + private static StateSizeEstimates build( + Map<OperatorID, Long> sizePerOperator, + Map<JobVertexID, Set<OperatorID>> 
verticesToOperators) { + Map<JobVertexID, Long> verticesToSizes = + verticesToOperators.entrySet().stream() + .collect( + toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator))); + return new StateSizeEstimates(verticesToSizes); + } + + private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) { + return ids.stream() + .mapToLong(key -> sizes.getOrDefault(key, 0L)) + .boxed() + .reduce(Long::sum) + .orElse(0L); + } + + private static Map<JobVertexID, Set<OperatorID>> mapVerticesToOperators( + ExecutionGraph executionGraph) { + return executionGraph.getAllVertices().entrySet().stream() + .collect(toMap(Map.Entry::getKey, e -> getOperatorIDS(e.getValue()))); + } + + private static Set<OperatorID> getOperatorIDS(ExecutionJobVertex v) { + return v.getOperatorIDs().stream() + .map(OperatorIDPair::getGeneratedOperatorID) + .collect(Collectors.toSet()); + } + + private static Map<OperatorID, Long> fromCompletedCheckpoint(CompletedCheckpoint cp) { + Stream<Map.Entry<OperatorID, OperatorState>> states = + cp.getOperatorStates().entrySet().stream(); + Map<OperatorID, Long> estimates = + states.collect( + toMap(Map.Entry::getKey, e -> estimateKeyGroupStateSize(e.getValue()))); + return estimates; + } + + private static long estimateKeyGroupStateSize(OperatorState state) { + Stream<KeyedStateHandle> handles = + state.getSubtaskStates().values().stream() + .flatMap(s -> s.getManagedKeyedState().stream()); + Stream<Tuple2<Long, Integer>> sizeAndCount = + handles.map( + h -> + Tuple2.of( + h.getStateSize(), + h.getKeyGroupRange().getNumberOfKeyGroups())); + Optional<Tuple2<Long, Integer>> totalSizeAndCount = + sizeAndCount.reduce( + (left, right) -> Tuple2.of(left.f0 + right.f0, left.f1 + right.f1)); + Optional<Long> average = totalSizeAndCount.filter(t2 -> t2.f1 > 0).map(t2 -> t2.f0 / t2.f1); + return average.orElse(0L); Review Comment: This method is tough to understand. 
I'm wondering whether the average size of the KeyGroup across states is what we should be looking for here 🤔 Wouldn't the maximum average size be more representative (assuming we're not worried about redistributing small states)? ```suggestion return state.getSubtaskStates().values().stream() .map(OperatorSubtaskState::getManagedKeyedState) .flatMap(Collection::stream) .mapToLong(handle -> handle.getStateSize() / handle.getKeyGroupRange().getNumberOfKeyGroups()) .max() .orElse(0L); ``` ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateLocalitySlotAssigner.java: ########## @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.clusterframework.types.AllocationID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; +import org.apache.flink.runtime.jobmaster.SlotInfo; +import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment; +import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; + +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; +import static org.apache.flink.runtime.scheduler.adaptive.allocator.DefaultSlotAssigner.createExecutionSlotSharingGroups; + +/** A {@link SlotAssigner} that assigns slots based on the number of local key groups. 
*/ +@Internal +public class StateLocalitySlotAssigner implements SlotAssigner { + + private static class AllocationScore implements Comparable<AllocationScore> { + + private final String group; + private final AllocationID allocationId; + + public AllocationScore(String group, AllocationID allocationId, long score) { + this.group = group; + this.allocationId = allocationId; + this.score = score; + } + + private final long score; + + public String getGroup() { + return group; + } + + public AllocationID getAllocationId() { + return allocationId; + } + + public long getScore() { + return score; + } + + @Override + public int compareTo(StateLocalitySlotAssigner.AllocationScore other) { + int result = Long.compare(score, other.score); + if (result != 0) { + return result; + } + result = other.allocationId.compareTo(allocationId); + if (result != 0) { + return result; + } + return other.group.compareTo(group); + } + } + + @Override + public Collection<SlotAssignment> assignSlots( + JobInformation jobInformation, + Collection<? 
extends SlotInfo> freeSlots, + VertexParallelism vertexParallelism, + AllocationsInfo previousAllocations, + StateSizeEstimates stateSizeEstimates) { + final List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>(); + for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) { + allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup)); + } + final Map<JobVertexID, Integer> parallelism = getParallelism(allGroups); + + // PQ orders the pairs (allocationID, groupID) by score, decreasing + // the score is computed as the potential amount of state that would reside locally + final PriorityQueue<AllocationScore> scores = + new PriorityQueue<>(Comparator.reverseOrder()); + for (ExecutionSlotSharingGroup group : allGroups) { + calculateScore( + group, + parallelism, + jobInformation, + previousAllocations, + stateSizeEstimates) + .entrySet().stream() + .map(e -> new AllocationScore(group.getId(), e.getKey(), e.getValue())) + .forEach(scores::add); + } + + Map<String, ExecutionSlotSharingGroup> groupsById = + allGroups.stream().collect(toMap(ExecutionSlotSharingGroup::getId, identity())); + Map<AllocationID, SlotInfo> slotsById = + freeSlots.stream().collect(toMap(SlotInfo::getAllocationId, identity())); + AllocationScore score; + final Collection<SlotAssignment> assignments = new ArrayList<>(); + while ((score = scores.poll()) != null) { + SlotInfo slot = slotsById.remove(score.getAllocationId()); + if (slot != null) { + ExecutionSlotSharingGroup group = groupsById.remove(score.getGroup()); + if (group != null) { + assignments.add(new SlotAssignment(slot, group)); + } + } + } + // Distribute the remaining slots with no score + Iterator<? 
extends SlotInfo> remainingSlots = slotsById.values().iterator(); + for (ExecutionSlotSharingGroup group : groupsById.values()) { + assignments.add(new SlotAssignment(remainingSlots.next(), group)); + remainingSlots.remove(); + } + + return assignments; + } + + private static Map<JobVertexID, Integer> getParallelism( + List<ExecutionSlotSharingGroup> groups) { + final Map<JobVertexID, Integer> parallelism = new HashMap<>(); + for (ExecutionSlotSharingGroup group : groups) { + for (ExecutionVertexID evi : group.getContainedExecutionVertices()) { + parallelism.merge(evi.getJobVertexId(), 1, Integer::sum); + } + } + return parallelism; + } + + public Map<AllocationID, Long> calculateScore( + ExecutionSlotSharingGroup group, + Map<JobVertexID, Integer> parallelism, + JobInformation jobInformation, + AllocationsInfo previousAllocations, + StateSizeEstimates stateSizeEstimates) { + final Map<AllocationID, Long> score = new HashMap<>(); + for (ExecutionVertexID evi : group.getContainedExecutionVertices()) { + final KeyGroupRange kgr = + KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( + jobInformation + .getVertexInformation(evi.getJobVertexId()) + .getMaxParallelism(), + parallelism.get(evi.getJobVertexId()), + evi.getSubtaskIndex()); + // Estimate state size per key group. For scoring, assume 1 if size estimate is 0 to + // accommodate for averaging non-zero states Review Comment: Do we care about states that have less than a byte on average? Can this even happen? Wouldn't this state be inlined anyway? We could then simplify `#estimate` to always return a number and only consider ones larger than zero. ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/DefaultSlotAssigner.java: ########## @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; +import org.apache.flink.runtime.jobmaster.SlotInfo; +import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment; +import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** Simple {@link SlotAssigner} that treats all slots and slot sharing groups equally. */ +public class DefaultSlotAssigner implements SlotAssigner { + + @Override + public Collection<SlotAssignment> assignSlots( + JobInformation jobInformation, + Collection<? 
extends SlotInfo> freeSlots, + VertexParallelism vertexParallelism, + AllocationsInfo previousAllocations, + StateSizeEstimates stateSizeEstimates) { + List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>(); + for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) { + allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup)); + } + + Iterator<? extends SlotInfo> iterator = freeSlots.iterator(); + Collection<SlotAssignment> assignments = new ArrayList<>(); + for (ExecutionSlotSharingGroup group : allGroups) { + assignments.add(new SlotAssignment(iterator.next(), group)); Review Comment: Should we have a safeguard against running out of free slots? ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java: ########## @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.KeyedStateHandle; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; + +/** Managed Keyed State size estimates used to make scheduling decisions. */ +class StateSizeEstimates { + private final Map<JobVertexID, Long> averages; + + public StateSizeEstimates() { + this(Collections.emptyMap()); + } + + public StateSizeEstimates(Map<JobVertexID, Long> averages) { + this.averages = averages; + } + + public Optional<Long> estimate(JobVertexID jobVertexId) { + return Optional.ofNullable(averages.get(jobVertexId)); + } + + static StateSizeEstimates empty() { + return new StateSizeEstimates(); + } + + static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) { + return Optional.ofNullable(executionGraph) + .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator())) + .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore())) + .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint())) + .map( + cp -> + build( + fromCompletedCheckpoint(cp), + mapVerticesToOperators(executionGraph))) + .orElse(empty()); + } + + private static StateSizeEstimates build( + Map<OperatorID, Long> sizePerOperator, + Map<JobVertexID, Set<OperatorID>> 
verticesToOperators) { + Map<JobVertexID, Long> verticesToSizes = + verticesToOperators.entrySet().stream() + .collect( + toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator))); + return new StateSizeEstimates(verticesToSizes); + } + + private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) { + return ids.stream() + .mapToLong(key -> sizes.getOrDefault(key, 0L)) + .boxed() + .reduce(Long::sum) + .orElse(0L); Review Comment: ```suggestion .sum(); ``` ########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/AllocationsInfo.java: ########## @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptive.allocator; + +import org.apache.flink.runtime.clusterframework.types.AllocationID; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; + +import static java.util.Collections.emptyMap; + +class AllocationsInfo { Review Comment: I'm wondering whether we could incorporate size estimates into the output of this class so we only have a single data structure that we need to pass around 🤔 Also, looking at how `#getAllocations` is used, what we want is a data structure that is indexed by JobVertexID, so we don't have to iterate over all allocations all over again. Something along the lines of: ``` static class AllocatedState { private final AllocationID allocationId; private final KeyGroupRange keyGroupRange; private final long keyGroupSizeEstimateInBytes; } Collection<AllocatedState> getAllocatedStates(JobVertexID); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org