JunRuiLee commented on code in PR #26180:
URL: https://github.com/apache/flink/pull/26180#discussion_r1963052390


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
             long globalDataVolumePerTask, long inputsGroupBytes, long 
totalDataBytes) {
         return (long) ((double) inputsGroupBytes / totalDataBytes * 
globalDataVolumePerTask);
     }
+
+    public static Optional<String> constructOptimizationLog(
+            BlockingInputInfo inputInfo, JobVertexInputInfo 
jobVertexInputInfo) {
+        if (inputInfo.areInterInputsKeysCorrelated() && 
inputInfo.isIntraInputKeyCorrelated()) {
+            return Optional.empty();
+        }
+        boolean optimized = false;
+        List<ExecutionVertexInputInfo> executionVertexInputInfos =
+                jobVertexInputInfo.getExecutionVertexInputInfos();
+        int parallelism = executionVertexInputInfos.size();
+        long[] optimizedDataBytes = new long[parallelism];
+        long optimizedMin = Long.MAX_VALUE, optimizedMax = 0;
+        long[] nonOptimizedDataBytes = new long[parallelism];
+        long nonOptimizedMin = Long.MAX_VALUE, nonOptimizedMax = 0;
+        for (int i = 0; i < parallelism; ++i) {
+            Map<IndexRange, IndexRange> consumedSubpartitionGroups =
+                    
executionVertexInputInfos.get(i).getConsumedSubpartitionGroups();
+            for (Map.Entry<IndexRange, IndexRange> entry : 
consumedSubpartitionGroups.entrySet()) {
+                IndexRange partitionRange = entry.getKey();
+                IndexRange subpartitionRange = entry.getValue();
+                optimizedDataBytes[i] +=
+                        inputInfo.getNumBytesProduced(partitionRange, 
subpartitionRange);
+            }
+            optimizedMin = Math.min(optimizedMin, optimizedDataBytes[i]);
+            optimizedMax = Math.max(optimizedMax, optimizedDataBytes[i]);
+
+            Map<IndexRange, IndexRange> nonOptimizedConsumedSubpartitionGroup =
+                    computeNumBasedConsumedSubpartitionGroup(parallelism, i, 
inputInfo);
+            checkState(nonOptimizedConsumedSubpartitionGroup.size() == 1);
+            nonOptimizedDataBytes[i] +=
+                    inputInfo.getNumBytesProduced(
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getKey(),
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getValue());
+            nonOptimizedMin = Math.min(nonOptimizedMin, 
nonOptimizedDataBytes[i]);
+            nonOptimizedMax = Math.max(nonOptimizedMax, 
nonOptimizedDataBytes[i]);
+
+            if (!optimized
+                    && 
!consumedSubpartitionGroups.equals(nonOptimizedConsumedSubpartitionGroup)) {
+                optimized = true;
+            }
+        }
+        if (optimized) {
+            long optimizedMed = median(optimizedDataBytes);
+            long nonOptimizedMed = median(nonOptimizedDataBytes);
+            String logMessage =
+                    String.format(
+                            "Result id: %s, "
+                                    + "type number: %d, "
+                                    + "input data size: "
+                                    + "[ Before: {min: %s, median: %s, max: 
%s}, "
+                                    + "After: {min: %s, median: %s, max: %s} 
]",
+                            inputInfo.getResultId(),
+                            inputInfo.getInputTypeNumber(),
+                            new 
MemorySize(nonOptimizedMin).toHumanReadableString(),
+                            new 
MemorySize(nonOptimizedMed).toHumanReadableString(),
+                            new 
MemorySize(nonOptimizedMax).toHumanReadableString(),
+                            new 
MemorySize(optimizedMin).toHumanReadableString(),
+                            new 
MemorySize(optimizedMed).toHumanReadableString(),
+                            new 
MemorySize(optimizedMax).toHumanReadableString());
+            return Optional.of(logMessage);
+        }
+        return Optional.empty();
+    }
+
+    private static Map<IndexRange, IndexRange> 
computeNumBasedConsumedSubpartitionGroup(
+            int parallelism, int currentIndex, BlockingInputInfo inputInfo) {
+        int sourceParallelism = inputInfo.getNumPartitions();
+
+        if (inputInfo.isPointwise()) {
+            return computeNumBasedConsumedSubpartitionGroupForPointwise(
+                    sourceParallelism, parallelism, currentIndex, 
inputInfo::getNumSubpartitions);
+        } else {
+            return computeNumBasedConsumedSubpartitionGroupForAllToAll(
+                    sourceParallelism,
+                    parallelism,
+                    currentIndex,
+                    inputInfo::getNumSubpartitions,
+                    inputInfo.isBroadcast(),
+                    inputInfo.isSingleSubpartitionContainsAllData());
+        }
+    }
+
+    static Map<IndexRange, IndexRange> 
computeNumBasedConsumedSubpartitionGroupForPointwise(
+            int sourceCount,
+            int targetCount,
+            int currentIndex,
+            Function<Integer, Integer> numOfSubpartitionsRetriever) {
+        if (sourceCount >= targetCount) {
+            int start = currentIndex * sourceCount / targetCount;
+            int end = (currentIndex + 1) * sourceCount / targetCount;
+            IndexRange partitionRange = new IndexRange(start, end - 1);
+            IndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            currentIndex,
+                            1,
+                            () -> numOfSubpartitionsRetriever.apply(start),
+                            true,
+                            false,
+                            false);
+            return Map.of(partitionRange, subpartitionRange);
+        } else {
+            int partitionNum = (currentIndex * sourceCount - sourceCount + 1) 
/ targetCount;
+            int start = (partitionNum * targetCount + sourceCount - 1) / 
sourceCount;
+            int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / 
sourceCount;
+            int numConsumers = end - start;
+            IndexRange partitionRange = new IndexRange(partitionNum, 
partitionNum);
+            IndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            currentIndex,
+                            numConsumers,
+                            () -> 
numOfSubpartitionsRetriever.apply(partitionNum),
+                            true,
+                            false,
+                            false);
+            return Map.of(partitionRange, subpartitionRange);
+        }
+    }
+
+    static Map<IndexRange, IndexRange> 
computeNumBasedConsumedSubpartitionGroupForAllToAll(

Review Comment:
   Ditto: please annotate this package-private method with `@VisibleForTesting` as well.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
             long globalDataVolumePerTask, long inputsGroupBytes, long 
totalDataBytes) {
         return (long) ((double) inputsGroupBytes / totalDataBytes * 
globalDataVolumePerTask);
     }
+
+    public static Optional<String> constructOptimizationLog(
+            BlockingInputInfo inputInfo, JobVertexInputInfo 
jobVertexInputInfo) {
+        if (inputInfo.areInterInputsKeysCorrelated() && 
inputInfo.isIntraInputKeyCorrelated()) {
+            return Optional.empty();
+        }
+        boolean optimized = false;
+        List<ExecutionVertexInputInfo> executionVertexInputInfos =
+                jobVertexInputInfo.getExecutionVertexInputInfos();
+        int parallelism = executionVertexInputInfos.size();
+        long[] optimizedDataBytes = new long[parallelism];
+        long optimizedMin = Long.MAX_VALUE, optimizedMax = 0;
+        long[] nonOptimizedDataBytes = new long[parallelism];
+        long nonOptimizedMin = Long.MAX_VALUE, nonOptimizedMax = 0;
+        for (int i = 0; i < parallelism; ++i) {
+            Map<IndexRange, IndexRange> consumedSubpartitionGroups =
+                    
executionVertexInputInfos.get(i).getConsumedSubpartitionGroups();
+            for (Map.Entry<IndexRange, IndexRange> entry : 
consumedSubpartitionGroups.entrySet()) {
+                IndexRange partitionRange = entry.getKey();
+                IndexRange subpartitionRange = entry.getValue();
+                optimizedDataBytes[i] +=
+                        inputInfo.getNumBytesProduced(partitionRange, 
subpartitionRange);
+            }
+            optimizedMin = Math.min(optimizedMin, optimizedDataBytes[i]);
+            optimizedMax = Math.max(optimizedMax, optimizedDataBytes[i]);
+
+            Map<IndexRange, IndexRange> nonOptimizedConsumedSubpartitionGroup =
+                    computeNumBasedConsumedSubpartitionGroup(parallelism, i, 
inputInfo);
+            checkState(nonOptimizedConsumedSubpartitionGroup.size() == 1);
+            nonOptimizedDataBytes[i] +=
+                    inputInfo.getNumBytesProduced(
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getKey(),
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getValue());
+            nonOptimizedMin = Math.min(nonOptimizedMin, 
nonOptimizedDataBytes[i]);
+            nonOptimizedMax = Math.max(nonOptimizedMax, 
nonOptimizedDataBytes[i]);
+
+            if (!optimized
+                    && 
!consumedSubpartitionGroups.equals(nonOptimizedConsumedSubpartitionGroup)) {
+                optimized = true;
+            }
+        }
+        if (optimized) {
+            long optimizedMed = median(optimizedDataBytes);
+            long nonOptimizedMed = median(nonOptimizedDataBytes);
+            String logMessage =
+                    String.format(
+                            "Result id: %s, "
+                                    + "type number: %d, "
+                                    + "input data size: "
+                                    + "[ Before: {min: %s, median: %s, max: 
%s}, "
+                                    + "After: {min: %s, median: %s, max: %s} 
]",
+                            inputInfo.getResultId(),
+                            inputInfo.getInputTypeNumber(),
+                            new 
MemorySize(nonOptimizedMin).toHumanReadableString(),
+                            new 
MemorySize(nonOptimizedMed).toHumanReadableString(),
+                            new 
MemorySize(nonOptimizedMax).toHumanReadableString(),
+                            new 
MemorySize(optimizedMin).toHumanReadableString(),
+                            new 
MemorySize(optimizedMed).toHumanReadableString(),
+                            new 
MemorySize(optimizedMax).toHumanReadableString());
+            return Optional.of(logMessage);
+        }
+        return Optional.empty();
+    }
+
+    private static Map<IndexRange, IndexRange> 
computeNumBasedConsumedSubpartitionGroup(
+            int parallelism, int currentIndex, BlockingInputInfo inputInfo) {
+        int sourceParallelism = inputInfo.getNumPartitions();
+
+        if (inputInfo.isPointwise()) {
+            return computeNumBasedConsumedSubpartitionGroupForPointwise(
+                    sourceParallelism, parallelism, currentIndex, 
inputInfo::getNumSubpartitions);
+        } else {
+            return computeNumBasedConsumedSubpartitionGroupForAllToAll(
+                    sourceParallelism,
+                    parallelism,
+                    currentIndex,
+                    inputInfo::getNumSubpartitions,
+                    inputInfo.isBroadcast(),
+                    inputInfo.isSingleSubpartitionContainsAllData());
+        }
+    }
+
+    static Map<IndexRange, IndexRange> 
computeNumBasedConsumedSubpartitionGroupForPointwise(

Review Comment:
   Please annotate this package-private method with `@VisibleForTesting`.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
             long globalDataVolumePerTask, long inputsGroupBytes, long 
totalDataBytes) {
         return (long) ((double) inputsGroupBytes / totalDataBytes * 
globalDataVolumePerTask);
     }
+
+    public static Optional<String> constructOptimizationLog(

Review Comment:
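   Could we print the optimization log directly in this method (instead of returning it as an `Optional<String>`) and split this long method into several smaller methods? For example, one helper could compute the per-subtask data sizes for the optimized and non-optimized assignments, and another could format the before/after min/median/max summary.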



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to