JunRuiLee commented on code in PR #26180:
URL: https://github.com/apache/flink/pull/26180#discussion_r1963052390
##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
             long globalDataVolumePerTask, long inputsGroupBytes, long totalDataBytes) {
         return (long) ((double) inputsGroupBytes / totalDataBytes * globalDataVolumePerTask);
     }
+
+    public static Optional<String> constructOptimizationLog(
+            BlockingInputInfo inputInfo, JobVertexInputInfo jobVertexInputInfo) {
+        if (inputInfo.areInterInputsKeysCorrelated() && inputInfo.isIntraInputKeyCorrelated()) {
+            return Optional.empty();
+        }
+        boolean optimized = false;
+        List<ExecutionVertexInputInfo> executionVertexInputInfos =
+                jobVertexInputInfo.getExecutionVertexInputInfos();
+        int parallelism = executionVertexInputInfos.size();
+        long[] optimizedDataBytes = new long[parallelism];
+        long optimizedMin = Long.MAX_VALUE, optimizedMax = 0;
+        long[] nonOptimizedDataBytes = new long[parallelism];
+        long nonOptimizedMin = Long.MAX_VALUE, nonOptimizedMax = 0;
+        for (int i = 0; i < parallelism; ++i) {
+            Map<IndexRange, IndexRange> consumedSubpartitionGroups =
+                    executionVertexInputInfos.get(i).getConsumedSubpartitionGroups();
+            for (Map.Entry<IndexRange, IndexRange> entry : consumedSubpartitionGroups.entrySet()) {
+                IndexRange partitionRange = entry.getKey();
+                IndexRange subpartitionRange = entry.getValue();
+                optimizedDataBytes[i] +=
+                        inputInfo.getNumBytesProduced(partitionRange, subpartitionRange);
+            }
+            optimizedMin = Math.min(optimizedMin, optimizedDataBytes[i]);
+            optimizedMax = Math.max(optimizedMax, optimizedDataBytes[i]);
+
+            Map<IndexRange, IndexRange> nonOptimizedConsumedSubpartitionGroup =
+                    computeNumBasedConsumedSubpartitionGroup(parallelism, i, inputInfo);
+            checkState(nonOptimizedConsumedSubpartitionGroup.size() == 1);
+            nonOptimizedDataBytes[i] +=
+                    inputInfo.getNumBytesProduced(
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getKey(),
+                            nonOptimizedConsumedSubpartitionGroup
+                                    .entrySet()
+                                    .iterator()
+                                    .next()
+                                    .getValue());
+            nonOptimizedMin = Math.min(nonOptimizedMin, nonOptimizedDataBytes[i]);
+            nonOptimizedMax = Math.max(nonOptimizedMax, nonOptimizedDataBytes[i]);
+
+            if (!optimized
+                    && !consumedSubpartitionGroups.equals(nonOptimizedConsumedSubpartitionGroup)) {
+                optimized = true;
+            }
+        }
+        if (optimized) {
+            long optimizedMed = median(optimizedDataBytes);
+            long nonOptimizedMed = median(nonOptimizedDataBytes);
+            String logMessage =
+                    String.format(
+                            "Result id: %s, "
+                                    + "type number: %d, "
+                                    + "input data size: "
+                                    + "[ Before: {min: %s, median: %s, max: %s}, "
+                                    + "After: {min: %s, median: %s, max: %s} ]",
+                            inputInfo.getResultId(),
+                            inputInfo.getInputTypeNumber(),
+                            new MemorySize(nonOptimizedMin).toHumanReadableString(),
+                            new MemorySize(nonOptimizedMed).toHumanReadableString(),
+                            new MemorySize(nonOptimizedMax).toHumanReadableString(),
+                            new MemorySize(optimizedMin).toHumanReadableString(),
+                            new MemorySize(optimizedMed).toHumanReadableString(),
+                            new MemorySize(optimizedMax).toHumanReadableString());
+            return Optional.of(logMessage);
+        }
+        return Optional.empty();
+    }
+
+    private static Map<IndexRange, IndexRange> computeNumBasedConsumedSubpartitionGroup(
+            int parallelism, int currentIndex, BlockingInputInfo inputInfo) {
+        int sourceParallelism = inputInfo.getNumPartitions();
+
+        if (inputInfo.isPointwise()) {
+            return computeNumBasedConsumedSubpartitionGroupForPointwise(
+                    sourceParallelism, parallelism, currentIndex, inputInfo::getNumSubpartitions);
+        } else {
+            return computeNumBasedConsumedSubpartitionGroupForAllToAll(
+                    sourceParallelism,
+                    parallelism,
+                    currentIndex,
+                    inputInfo::getNumSubpartitions,
+                    inputInfo.isBroadcast(),
+                    inputInfo.isSingleSubpartitionContainsAllData());
+        }
+    }
+
+    static Map<IndexRange, IndexRange> computeNumBasedConsumedSubpartitionGroupForPointwise(
+            int sourceCount,
+            int targetCount,
+            int currentIndex,
+            Function<Integer, Integer> numOfSubpartitionsRetriever) {
+        if (sourceCount >= targetCount) {
+            int start = currentIndex * sourceCount / targetCount;
+            int end = (currentIndex + 1) * sourceCount / targetCount;
+            IndexRange partitionRange = new IndexRange(start, end - 1);
+            IndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            currentIndex,
+                            1,
+                            () -> numOfSubpartitionsRetriever.apply(start),
+                            true,
+                            false,
+                            false);
+            return Map.of(partitionRange, subpartitionRange);
+        } else {
+            int partitionNum = (currentIndex * sourceCount - sourceCount + 1) / targetCount;
+            int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
+            int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
+            int numConsumers = end - start;
+            IndexRange partitionRange = new IndexRange(partitionNum, partitionNum);
+            IndexRange subpartitionRange =
+                    computeConsumedSubpartitionRange(
+                            currentIndex,
+                            numConsumers,
+                            () -> numOfSubpartitionsRetriever.apply(partitionNum),
+                            true,
+                            false,
+                            false);
+            return Map.of(partitionRange, subpartitionRange);
+        }
+    }
+
+    static Map<IndexRange, IndexRange> computeNumBasedConsumedSubpartitionGroupForAllToAll(

Review Comment:
   ditto
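For orientation on the quoted hunk: the num-based (non-optimized) partition range that the optimization log compares against is plain integer-division arithmetic, visible in the sourceCount >= targetCount branch of computeNumBasedConsumedSubpartitionGroupForPointwise above. A minimal standalone sketch of just that arithmetic (plain Java, no Flink types; the class name and the 7-partitions-to-3-subtasks sample are illustrative only, not from the PR):

    // Standalone illustration of the partition-range arithmetic quoted above
    // (sourceCount >= targetCount branch of the pointwise helper).
    public class PointwiseRangeDemo {

        // Inclusive {start, end} partition range consumed by the given downstream subtask.
        static int[] partitionRange(int sourceCount, int targetCount, int currentIndex) {
            int start = currentIndex * sourceCount / targetCount;
            int end = (currentIndex + 1) * sourceCount / targetCount;
            return new int[] {start, end - 1};
        }

        public static void main(String[] args) {
            // 7 upstream partitions consumed by 3 downstream subtasks:
            // subtask 0 -> [0, 1], subtask 1 -> [2, 3], subtask 2 -> [4, 6]
            for (int i = 0; i < 3; i++) {
                int[] r = partitionRange(7, 3, i);
                System.out.println("subtask " + i + " -> [" + r[0] + ", " + r[1] + "]");
            }
        }
    }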
##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
+    static Map<IndexRange, IndexRange> computeNumBasedConsumedSubpartitionGroupForPointwise(

Review Comment:
   VisibleForTesting
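A sketch of what the suggested change could look like, assuming Flink's org.apache.flink.annotation.VisibleForTesting annotation; only the annotation is new, the signature is taken from the diff, and the body would stay unchanged. The "ditto" above appears to ask for the same treatment on the all-to-all variant:

    import org.apache.flink.annotation.VisibleForTesting;

    // Package-private helper kept non-public, but explicitly marked as exposed for tests.
    @VisibleForTesting
    static Map<IndexRange, IndexRange> computeNumBasedConsumedSubpartitionGroupForPointwise(
            int sourceCount,
            int targetCount,
            int currentIndex,
            Function<Integer, Integer> numOfSubpartitionsRetriever) {
        // ... body exactly as in the quoted diff ...
    }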
##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -658,4 +660,148 @@ public static long calculateDataVolumePerTaskForInput(
             long globalDataVolumePerTask, long inputsGroupBytes, long totalDataBytes) {
         return (long) ((double) inputsGroupBytes / totalDataBytes * globalDataVolumePerTask);
     }
+
+    public static Optional<String> constructOptimizationLog(

Review Comment:
   Could we print the optimization log in this method and split this long method into several smaller methods?
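One possible shape for that refactoring, purely as a sketch: two small helpers compute the optimized and num-based per-subtask byte distributions, and the method itself formats and logs the comparison instead of returning a string. Everything below is illustrative (the class, method names, the stdout logging, and the simplified "did the assignment change" check are not from the PR):

    import java.util.Arrays;

    // Self-contained sketch of the suggested decomposition: small helpers for the
    // statistics, and one method that both builds and prints the optimization log.
    class OptimizationLogSketch {

        static long min(long[] v) {
            return Arrays.stream(v).min().orElse(0L);
        }

        static long max(long[] v) {
            return Arrays.stream(v).max().orElse(0L);
        }

        static long median(long[] v) {
            long[] sorted = v.clone();
            Arrays.sort(sorted);
            return sorted[sorted.length / 2];
        }

        static String describe(long[] bytesPerSubtask) {
            return String.format(
                    "{min: %d, median: %d, max: %d}",
                    min(bytesPerSubtask), median(bytesPerSubtask), max(bytesPerSubtask));
        }

        // In the real code the two arrays would come from dedicated helpers (optimized vs.
        // num-based assignment) and a proper logger would replace stdout.
        static void logOptimizationInfo(long[] nonOptimized, long[] optimized) {
            if (Arrays.equals(nonOptimized, optimized)) {
                return; // simplified stand-in for "the subpartition assignment did not change"
            }
            System.out.printf(
                    "input data size: [ Before: %s, After: %s ]%n",
                    describe(nonOptimized), describe(optimized));
        }

        public static void main(String[] args) {
            logOptimizationInfo(new long[] {100, 900, 500}, new long[] {500, 500, 500});
        }
    }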