dmvk commented on code in PR #21981:
URL: https://github.com/apache/flink/pull/21981#discussion_r1116631711


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/JobSchedulingPlan.java:
##########
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.jobmaster.SlotInfo;
+import org.apache.flink.runtime.scheduler.adaptive.allocator.VertexParallelism;
+
+import java.util.Collection;
+
+/**
+ * A plan that describes how to execute {@link org.apache.flink.runtime.jobgraph.JobGraph JobGraph}.
+ *
+ * <ol>
+ *   <li>{@link #vertexParallelism} is necessary to create {@link
+ *       org.apache.flink.runtime.executiongraph.ExecutionGraph ExecutionGraph}
+ *   <li>{@link #slotAssignments} are used to schedule it onto the cluster
+ * </ol>
+ *
+ * {@link AdaptiveScheduler} passes this structure from {@link WaitingForResources} to {@link
+ * CreatingExecutionGraph} stages.
+ */
+@Internal
+public class JobSchedulingPlan {
+    private final VertexParallelism vertexParallelism;
+    private final Collection<SlotAssignment> slotAssignments;
+
+    public JobSchedulingPlan(
+            VertexParallelism vertexParallelism, Collection<SlotAssignment> slotAssignments) {
+        this.vertexParallelism = vertexParallelism;
+        this.slotAssignments = slotAssignments;
+    }
+
+    public VertexParallelism getVertexParallelism() {
+        return vertexParallelism;
+    }
+
+    public Collection<SlotAssignment> getSlotAssignments() {
+        return slotAssignments;
+    }
+
+    /** Assignment of a slot to some target (e.g. a slot sharing group). */
+    public static class SlotAssignment {
+        private final SlotInfo slotInfo;
+        /**
+         * Interpreted by {@link
+         * org.apache.flink.runtime.scheduler.adaptive.allocator.SlotAllocator#tryReserveResources(JobSchedulingPlan)}.
+         * This can be a slot sharing group, a task, or something else.
+         */
+        private final Object target;

Review Comment:
   In the current code this is always `ExecutionSlotSharingGroup`
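   If it is meant to stay that way, a typed field might be easier to follow; a rough sketch only (the javadoc above explicitly allows other targets, so this is not what the PR does):
   ```java
   // Hypothetical variant with a concrete target type instead of Object.
   public static class SlotAssignment {
       private final SlotInfo slotInfo;
       private final ExecutionSlotSharingGroup target;

       public SlotAssignment(SlotInfo slotInfo, ExecutionSlotSharingGroup target) {
           this.slotInfo = slotInfo;
           this.target = target;
       }

       public ExecutionSlotSharingGroup getTarget() {
           return target;
       }
   }
   ```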



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/SlotSharingSlotAllocator.java:
##########
@@ -93,20 +93,15 @@ private static Map<SlotSharingGroupId, Integer> getMaxParallelismForSlotSharingG
     }
 
     @Override
-    public Optional<VertexParallelismWithSlotSharing> determineParallelism(
+    public Optional<? extends VertexParallelism> determineParallelism(

Review Comment:
   Since there are no classes extending from `VertexParallelism`, would it make sense to change this?
   ```suggestion
       public Optional<VertexParallelism> determineParallelism(
   ```



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateLocalitySlotAssigner.java:
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.clusterframework.types.AllocationID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
+import org.apache.flink.runtime.jobmaster.SlotInfo;
+import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment;
+import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.PriorityQueue;
+
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.flink.runtime.scheduler.adaptive.allocator.DefaultSlotAssigner.createExecutionSlotSharingGroups;
+
+/** A {@link SlotAssigner} that assigns slots based on the number of local key groups. */
+@Internal
+public class StateLocalitySlotAssigner implements SlotAssigner {
+
+    private static class AllocationScore implements Comparable<AllocationScore> {
+
+        private final String group;
+        private final AllocationID allocationId;
+
+        public AllocationScore(String group, AllocationID allocationId, long score) {
+            this.group = group;
+            this.allocationId = allocationId;
+            this.score = score;
+        }
+
+        private final long score;
+
+        public String getGroup() {
+            return group;
+        }
+
+        public AllocationID getAllocationId() {
+            return allocationId;
+        }
+
+        public long getScore() {
+            return score;
+        }
+
+        @Override
+        public int compareTo(StateLocalitySlotAssigner.AllocationScore other) {
+            int result = Long.compare(score, other.score);
+            if (result != 0) {
+                return result;
+            }
+            result = other.allocationId.compareTo(allocationId);
+            if (result != 0) {
+                return result;
+            }
+            return other.group.compareTo(group);
+        }
+    }
+
+    @Override
+    public Collection<SlotAssignment> assignSlots(
+            JobInformation jobInformation,
+            Collection<? extends SlotInfo> freeSlots,
+            VertexParallelism vertexParallelism,
+            AllocationsInfo previousAllocations,
+            StateSizeEstimates stateSizeEstimates) {
+        final List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>();
+        for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) {
+            allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup));
+        }
+        final Map<JobVertexID, Integer> parallelism = getParallelism(allGroups);
+
+        // PQ orders the pairs (allocationID, groupID) by score, decreasing
+        // the score is computed as the potential amount of state that would reside locally
+        final PriorityQueue<AllocationScore> scores =
+                new PriorityQueue<>(Comparator.reverseOrder());
+        for (ExecutionSlotSharingGroup group : allGroups) {
+            calculateScore(
+                            group,
+                            parallelism,
+                            jobInformation,
+                            previousAllocations,
+                            stateSizeEstimates)
+                    .entrySet().stream()
+                    .map(e -> new AllocationScore(group.getId(), e.getKey(), e.getValue()))
+                    .forEach(scores::add);
+        }
+
+        Map<String, ExecutionSlotSharingGroup> groupsById =
+                allGroups.stream().collect(toMap(ExecutionSlotSharingGroup::getId, identity()));
+        Map<AllocationID, SlotInfo> slotsById =
+                freeSlots.stream().collect(toMap(SlotInfo::getAllocationId, identity()));
+        AllocationScore score;
+        final Collection<SlotAssignment> assignments = new ArrayList<>();
+        while ((score = scores.poll()) != null) {
+            SlotInfo slot = slotsById.remove(score.getAllocationId());
+            if (slot != null) {
+                ExecutionSlotSharingGroup group = groupsById.remove(score.getGroup());
+                if (group != null) {

Review Comment:
   If the group happens to be null, does it mean we're wasting a slot (because it has already been removed from the `slotsById` map)?
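   If so, one option could be to put the slot back when its group has already been taken, so it can still serve a group without a score later on; a rough sketch against the quoted loop (names taken from the diff above):
   ```java
   while ((score = scores.poll()) != null) {
       SlotInfo slot = slotsById.remove(score.getAllocationId());
       if (slot != null) {
           ExecutionSlotSharingGroup group = groupsById.remove(score.getGroup());
           if (group != null) {
               assignments.add(new SlotAssignment(slot, group));
           } else {
               // The group was already served by a better-scoring slot;
               // return this slot so the remaining groups can still use it.
               slotsById.put(score.getAllocationId(), slot);
           }
       }
   }
   ```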



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java:
##########
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.OperatorState;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toMap;
+
+/** Managed Keyed State size estimates used to make scheduling decisions. */
+class StateSizeEstimates {
+    private final Map<JobVertexID, Long> averages;
+
+    public StateSizeEstimates() {
+        this(Collections.emptyMap());
+    }
+
+    public StateSizeEstimates(Map<JobVertexID, Long> averages) {
+        this.averages = averages;
+    }
+
+    public Optional<Long> estimate(JobVertexID jobVertexId) {
+        return Optional.ofNullable(averages.get(jobVertexId));
+    }
+
+    static StateSizeEstimates empty() {
+        return new StateSizeEstimates();
+    }
+
+    static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) {
+        return Optional.ofNullable(executionGraph)
+                .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator()))
+                .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore()))
+                .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint()))
+                .map(
+                        cp ->
+                                build(
+                                        fromCompletedCheckpoint(cp),
+                                        mapVerticesToOperators(executionGraph)))
+                .orElse(empty());
+    }
+
+    private static StateSizeEstimates build(
+            Map<OperatorID, Long> sizePerOperator,
+            Map<JobVertexID, Set<OperatorID>> verticesToOperators) {
+        Map<JobVertexID, Long> verticesToSizes =
+                verticesToOperators.entrySet().stream()
+                        .collect(
+                                toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator)));
+        return new StateSizeEstimates(verticesToSizes);
+    }
+
+    private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) {
+        return ids.stream()
+                .mapToLong(key -> sizes.getOrDefault(key, 0L))
+                .boxed()
+                .reduce(Long::sum)
+                .orElse(0L);
+    }
+
+    private static Map<JobVertexID, Set<OperatorID>> mapVerticesToOperators(
+            ExecutionGraph executionGraph) {
+        return executionGraph.getAllVertices().entrySet().stream()
+                .collect(toMap(Map.Entry::getKey, e -> getOperatorIDS(e.getValue())));
+    }
+
+    private static Set<OperatorID> getOperatorIDS(ExecutionJobVertex v) {
+        return v.getOperatorIDs().stream()
+                .map(OperatorIDPair::getGeneratedOperatorID)
+                .collect(Collectors.toSet());
+    }
+
+    private static Map<OperatorID, Long> fromCompletedCheckpoint(CompletedCheckpoint cp) {
+        Stream<Map.Entry<OperatorID, OperatorState>> states =
+                cp.getOperatorStates().entrySet().stream();
+        Map<OperatorID, Long> estimates =
+                states.collect(
+                        toMap(Map.Entry::getKey, e -> estimateKeyGroupStateSize(e.getValue())));
+        return estimates;
+    }
+
+    private static long estimateKeyGroupStateSize(OperatorState state) {

Review Comment:
   ```suggestion
       private static long calculateAverageKeyGroupStateSizeInBytes(OperatorState state) {
   ```



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java:
##########
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.OperatorState;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toMap;
+
+/** Managed Keyed State size estimates used to make scheduling decisions. */
+class StateSizeEstimates {

Review Comment:
   Can we have some tests for this class?
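   A minimal starting point could look roughly like this (test names and the JUnit 5 + AssertJ setup are my assumption, not part of the PR):
   ```java
   import static org.assertj.core.api.Assertions.assertThat;

   import org.apache.flink.runtime.jobgraph.JobVertexID;

   import org.junit.jupiter.api.Test;

   import java.util.Collections;

   /** Hypothetical tests for {@code StateSizeEstimates}. */
   class StateSizeEstimatesTest {

       @Test
       void estimateReturnsAverageForKnownVertex() {
           JobVertexID vertex = new JobVertexID();
           StateSizeEstimates estimates =
                   new StateSizeEstimates(Collections.singletonMap(vertex, 42L));
           assertThat(estimates.estimate(vertex)).contains(42L);
       }

       @Test
       void estimateIsEmptyForUnknownVertex() {
           assertThat(StateSizeEstimates.empty().estimate(new JobVertexID())).isEmpty();
       }
   }
   ```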



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java:
##########
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.OperatorState;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toMap;
+
+/** Managed Keyed State size estimates used to make scheduling decisions. */
+class StateSizeEstimates {
+    private final Map<JobVertexID, Long> averages;
+
+    public StateSizeEstimates() {
+        this(Collections.emptyMap());
+    }
+
+    public StateSizeEstimates(Map<JobVertexID, Long> averages) {
+        this.averages = averages;
+    }
+
+    public Optional<Long> estimate(JobVertexID jobVertexId) {
+        return Optional.ofNullable(averages.get(jobVertexId));
+    }
+
+    static StateSizeEstimates empty() {
+        return new StateSizeEstimates();
+    }
+
+    static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) {
+        return Optional.ofNullable(executionGraph)
+                .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator()))
+                .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore()))
+                .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint()))
+                .map(
+                        cp ->
+                                build(
+                                        fromCompletedCheckpoint(cp),
+                                        mapVerticesToOperators(executionGraph)))
+                .orElse(empty());
+    }
+
+    private static StateSizeEstimates build(
+            Map<OperatorID, Long> sizePerOperator,
+            Map<JobVertexID, Set<OperatorID>> verticesToOperators) {
+        Map<JobVertexID, Long> verticesToSizes =
+                verticesToOperators.entrySet().stream()
+                        .collect(
+                                toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator)));
+        return new StateSizeEstimates(verticesToSizes);
+    }
+
+    private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) {
+        return ids.stream()
+                .mapToLong(key -> sizes.getOrDefault(key, 0L))
+                .boxed()
+                .reduce(Long::sum)
+                .orElse(0L);
+    }
+
+    private static Map<JobVertexID, Set<OperatorID>> mapVerticesToOperators(
+            ExecutionGraph executionGraph) {
+        return executionGraph.getAllVertices().entrySet().stream()
+                .collect(toMap(Map.Entry::getKey, e -> getOperatorIDS(e.getValue())));
+    }
+
+    private static Set<OperatorID> getOperatorIDS(ExecutionJobVertex v) {
+        return v.getOperatorIDs().stream()
+                .map(OperatorIDPair::getGeneratedOperatorID)
+                .collect(Collectors.toSet());
+    }
+
+    private static Map<OperatorID, Long> fromCompletedCheckpoint(CompletedCheckpoint cp) {
+        Stream<Map.Entry<OperatorID, OperatorState>> states =
+                cp.getOperatorStates().entrySet().stream();
+        Map<OperatorID, Long> estimates =
+                states.collect(
+                        toMap(Map.Entry::getKey, e -> estimateKeyGroupStateSize(e.getValue())));
+        return estimates;
+    }
+
+    private static long estimateKeyGroupStateSize(OperatorState state) {
+        Stream<KeyedStateHandle> handles =
+                state.getSubtaskStates().values().stream()
+                        .flatMap(s -> s.getManagedKeyedState().stream());
+        Stream<Tuple2<Long, Integer>> sizeAndCount =
+                handles.map(
+                        h ->
+                                Tuple2.of(
+                                        h.getStateSize(),
+                                        h.getKeyGroupRange().getNumberOfKeyGroups()));
+        Optional<Tuple2<Long, Integer>> totalSizeAndCount =
+                sizeAndCount.reduce(
+                        (left, right) -> Tuple2.of(left.f0 + right.f0, left.f1 + right.f1));
+        Optional<Long> average = totalSizeAndCount.filter(t2 -> t2.f1 > 0).map(t2 -> t2.f0 / t2.f1);
+        return average.orElse(0L);

Review Comment:
   This method is tough to understand. I'm wondering whether the average size of the KeyGroup across states is what we should be looking for here 🤔 Wouldn't the maximum average size be more representative (assuming we're not worried about redistributing small states)?
   ```suggestion
           return state.getSubtaskStates().values().stream()
                   .map(OperatorSubtaskState::getManagedKeyedState)
                   .flatMap(Collection::stream)
                   .mapToLong(handle -> handle.getStateSize() / handle.getKeyGroupRange().getNumberOfKeyGroups())
                   .max()
                   .orElse(0L);
   ```



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateLocalitySlotAssigner.java:
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.clusterframework.types.AllocationID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
+import org.apache.flink.runtime.jobmaster.SlotInfo;
+import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment;
+import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.PriorityQueue;
+
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.flink.runtime.scheduler.adaptive.allocator.DefaultSlotAssigner.createExecutionSlotSharingGroups;
+
+/** A {@link SlotAssigner} that assigns slots based on the number of local key groups. */
+@Internal
+public class StateLocalitySlotAssigner implements SlotAssigner {
+
+    private static class AllocationScore implements Comparable<AllocationScore> {
+
+        private final String group;
+        private final AllocationID allocationId;
+
+        public AllocationScore(String group, AllocationID allocationId, long score) {
+            this.group = group;
+            this.allocationId = allocationId;
+            this.score = score;
+        }
+
+        private final long score;
+
+        public String getGroup() {
+            return group;
+        }
+
+        public AllocationID getAllocationId() {
+            return allocationId;
+        }
+
+        public long getScore() {
+            return score;
+        }
+
+        @Override
+        public int compareTo(StateLocalitySlotAssigner.AllocationScore other) {
+            int result = Long.compare(score, other.score);
+            if (result != 0) {
+                return result;
+            }
+            result = other.allocationId.compareTo(allocationId);
+            if (result != 0) {
+                return result;
+            }
+            return other.group.compareTo(group);
+        }
+    }
+
+    @Override
+    public Collection<SlotAssignment> assignSlots(
+            JobInformation jobInformation,
+            Collection<? extends SlotInfo> freeSlots,
+            VertexParallelism vertexParallelism,
+            AllocationsInfo previousAllocations,
+            StateSizeEstimates stateSizeEstimates) {
+        final List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>();
+        for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) {
+            allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup));
+        }
+        final Map<JobVertexID, Integer> parallelism = getParallelism(allGroups);
+
+        // PQ orders the pairs (allocationID, groupID) by score, decreasing
+        // the score is computed as the potential amount of state that would reside locally
+        final PriorityQueue<AllocationScore> scores =
+                new PriorityQueue<>(Comparator.reverseOrder());
+        for (ExecutionSlotSharingGroup group : allGroups) {
+            calculateScore(
+                            group,
+                            parallelism,
+                            jobInformation,
+                            previousAllocations,
+                            stateSizeEstimates)
+                    .entrySet().stream()
+                    .map(e -> new AllocationScore(group.getId(), e.getKey(), e.getValue()))
+                    .forEach(scores::add);
+        }
+
+        Map<String, ExecutionSlotSharingGroup> groupsById =
+                allGroups.stream().collect(toMap(ExecutionSlotSharingGroup::getId, identity()));
+        Map<AllocationID, SlotInfo> slotsById =
+                freeSlots.stream().collect(toMap(SlotInfo::getAllocationId, identity()));
+        AllocationScore score;
+        final Collection<SlotAssignment> assignments = new ArrayList<>();
+        while ((score = scores.poll()) != null) {
+            SlotInfo slot = slotsById.remove(score.getAllocationId());
+            if (slot != null) {
+                ExecutionSlotSharingGroup group = groupsById.remove(score.getGroup());
+                if (group != null) {
+                    assignments.add(new SlotAssignment(slot, group));
+                }
+            }
+        }
+        // Distribute the remaining slots with no score
+        Iterator<? extends SlotInfo> remainingSlots = slotsById.values().iterator();
+        for (ExecutionSlotSharingGroup group : groupsById.values()) {
+            assignments.add(new SlotAssignment(remainingSlots.next(), group));
+            remainingSlots.remove();
+        }
+
+        return assignments;
+    }
+
+    private static Map<JobVertexID, Integer> getParallelism(
+            List<ExecutionSlotSharingGroup> groups) {
+        final Map<JobVertexID, Integer> parallelism = new HashMap<>();
+        for (ExecutionSlotSharingGroup group : groups) {
+            for (ExecutionVertexID evi : group.getContainedExecutionVertices()) {
+                parallelism.merge(evi.getJobVertexId(), 1, Integer::sum);
+            }
+        }
+        return parallelism;
+    }
+
+    public Map<AllocationID, Long> calculateScore(
+            ExecutionSlotSharingGroup group,
+            Map<JobVertexID, Integer> parallelism,
+            JobInformation jobInformation,
+            AllocationsInfo previousAllocations,
+            StateSizeEstimates stateSizeEstimates) {
+        final Map<AllocationID, Long> score = new HashMap<>();
+        for (ExecutionVertexID evi : group.getContainedExecutionVertices()) {
+            final KeyGroupRange kgr =
+                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
+                            jobInformation
+                                    .getVertexInformation(evi.getJobVertexId())
+                                    .getMaxParallelism(),
+                            parallelism.get(evi.getJobVertexId()),
+                            evi.getSubtaskIndex());
+            // Estimate state size per key group. For scoring, assume 1 if size estimate is 0 to
+            // accommodate for averaging non-zero states

Review Comment:
   Do we care about states that have less than a byte on average? Can this even happen? Wouldn't this state be inlined anyway?
   
   We could then simplify `#estimate` to always return a number and only consider ones larger than zero.
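   If we drop the `Optional`, a rough sketch (zero standing in for "no estimate"):
   ```java
   // Hypothetical simplification of StateSizeEstimates#estimate: plain long, zero means "unknown".
   public long estimate(JobVertexID jobVertexId) {
       return averages.getOrDefault(jobVertexId, 0L);
   }
   ```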



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/DefaultSlotAssigner.java:
##########
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
+import org.apache.flink.runtime.jobmaster.SlotInfo;
+import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment;
+import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/** Simple {@link SlotAssigner} that treats all slots and slot sharing groups equally. */
+public class DefaultSlotAssigner implements SlotAssigner {
+
+    @Override
+    public Collection<SlotAssignment> assignSlots(
+            JobInformation jobInformation,
+            Collection<? extends SlotInfo> freeSlots,
+            VertexParallelism vertexParallelism,
+            AllocationsInfo previousAllocations,
+            StateSizeEstimates stateSizeEstimates) {
+        List<ExecutionSlotSharingGroup> allGroups = new ArrayList<>();
+        for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) {
+            allGroups.addAll(createExecutionSlotSharingGroups(vertexParallelism, slotSharingGroup));
+        }
+
+        Iterator<? extends SlotInfo> iterator = freeSlots.iterator();
+        Collection<SlotAssignment> assignments = new ArrayList<>();
+        for (ExecutionSlotSharingGroup group : allGroups) {
+            assignments.add(new SlotAssignment(iterator.next(), group));

Review Comment:
   Should we have a safeguard against running out of free slots?
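   For example, a precondition before the assignment loop (just a sketch; `Preconditions` is the existing `org.apache.flink.util.Preconditions`):
   ```java
   // Hypothetical safeguard: fail fast instead of letting iterator.next() throw NoSuchElementException.
   Preconditions.checkState(
           freeSlots.size() >= allGroups.size(),
           "Not enough free slots (%s) to assign all %s execution slot sharing groups.",
           freeSlots.size(),
           allGroups.size());
   ```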



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/StateSizeEstimates.java:
##########
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.OperatorState;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toMap;
+
+/** Managed Keyed State size estimates used to make scheduling decisions. */
+class StateSizeEstimates {
+    private final Map<JobVertexID, Long> averages;
+
+    public StateSizeEstimates() {
+        this(Collections.emptyMap());
+    }
+
+    public StateSizeEstimates(Map<JobVertexID, Long> averages) {
+        this.averages = averages;
+    }
+
+    public Optional<Long> estimate(JobVertexID jobVertexId) {
+        return Optional.ofNullable(averages.get(jobVertexId));
+    }
+
+    static StateSizeEstimates empty() {
+        return new StateSizeEstimates();
+    }
+
+    static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) {
+        return Optional.ofNullable(executionGraph)
+                .flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator()))
+                .flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore()))
+                .flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint()))
+                .map(
+                        cp ->
+                                build(
+                                        fromCompletedCheckpoint(cp),
+                                        mapVerticesToOperators(executionGraph)))
+                .orElse(empty());
+    }
+
+    private static StateSizeEstimates build(
+            Map<OperatorID, Long> sizePerOperator,
+            Map<JobVertexID, Set<OperatorID>> verticesToOperators) {
+        Map<JobVertexID, Long> verticesToSizes =
+                verticesToOperators.entrySet().stream()
+                        .collect(
+                                toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator)));
+        return new StateSizeEstimates(verticesToSizes);
+    }
+
+    private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) {
+        return ids.stream()
+                .mapToLong(key -> sizes.getOrDefault(key, 0L))
+                .boxed()
+                .reduce(Long::sum)
+                .orElse(0L);

Review Comment:
   ```suggestion
                   .sum();
   ```



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/allocator/AllocationsInfo.java:
##########
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptive.allocator;
+
+import org.apache.flink.runtime.clusterframework.types.AllocationID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static java.util.Collections.emptyMap;
+
+class AllocationsInfo {

Review Comment:
   I'm wondering whether we could incorporate size estimates into the output of this class so we only have a single data structure that we need to pass around 🤔
   
   Also, looking at how `#getAllocations` is used, what we want is a data structure that is indexed by JobVertexID, so we don't have to iterate over all allocations all over again.
   
   Something along the lines of:
   
   ```
   static class AllocatedState {
     private final AllocationID allocationId;
     private final KeyGroupRange keyGroupRange;
     private final long keyGroupSizeEstimateInBytes;
   }
   
   Collection<AllocatedState> getAllocatedStates(JobVertexID);
   ```
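   A hedged sketch of how the scoring side might consume such a structure (the accessor names are made up here; `KeyGroupRange#getIntersection` is existing API):
   ```java
   // Hypothetical consumer in StateLocalitySlotAssigner#calculateScore, assuming
   // AllocatedState exposes getters for the fields sketched above.
   for (ExecutionVertexID evi : group.getContainedExecutionVertices()) {
       for (AllocatedState allocated : previousAllocations.getAllocatedStates(evi.getJobVertexId())) {
           long localKeyGroups =
                   allocated.getKeyGroupRange().getIntersection(kgr).getNumberOfKeyGroups();
           score.merge(
                   allocated.getAllocationId(),
                   localKeyGroups * allocated.getKeyGroupSizeEstimateInBytes(),
                   Long::sum);
       }
   }
   ```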
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

