zhipeng93 commented on code in PR #227:
URL: https://github.com/apache/flink-ml/pull/227#discussion_r1155538452

##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -288,13 +291,20 @@ public static <IN, ACC, OUT> DataStream<OUT> aggregate(
     public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
         int inputParallelism = input.getParallelism();
-        return input.transform(
+        // In a worst-case scenario, the data partition with the greatest number of elements has
+        // `inputParallelism` additional elements compared to the one with the fewest elements even
+        // after `rebalance` performed. Therefore, additional elements are sampled from each
+        // partition in case some partitions has insufficient elements.
+        int firstRoundNumSamples =

Review Comment:
   nit: The maximum difference in the number of data points between partitions after calling `rebalance` is `inputParallelism`. As a result, `inputParallelism` extra data points are sampled from each partition in the first round.

##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -288,13 +291,20 @@ public static <IN, ACC, OUT> DataStream<OUT> aggregate(
     public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
         int inputParallelism = input.getParallelism();
-        return input.transform(
+        // In a worst-case scenario, the data partition with the greatest number of elements has
+        // `inputParallelism` additional elements compared to the one with the fewest elements even
+        // after `rebalance` performed. Therefore, additional elements are sampled from each
+        // partition in case some partitions has insufficient elements.
+        int firstRoundNumSamples =
+                Math.min((numSamples / inputParallelism) + inputParallelism, numSamples);
+        return input.rebalance()
+                .transform(
                         "samplingOperator",
                         input.getType(),
-                        new SamplingOperator<>(numSamples, randomSeed))
+                        new SamplingOperator<>(firstRoundNumSamples, randomSeed))
                 .setParallelism(inputParallelism)
                 .transform(
-                        "samplingOperator",
+                        "samplingOperator-2nd-round",

Review Comment:
   nit: Let's make the sampling operator names consistent.
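
For reference, the first-round count from the diff can be checked in isolation. The sketch below is only an illustration: the class `FirstRoundSampleCount` and the helper `collectedAfterFirstRound` are hypothetical names that do not exist in Flink ML, and the partition sizes are made-up numbers. Only the `Math.min(...)` expression is taken from the change under review.

```java
// Hypothetical helper class for illustration; not part of Flink ML.
final class FirstRoundSampleCount {

    // Same expression as in the diff: an even share of numSamples plus
    // inputParallelism extra elements, capped at numSamples.
    static int firstRoundNumSamples(int numSamples, int inputParallelism) {
        return Math.min((numSamples / inputParallelism) + inputParallelism, numSamples);
    }

    // Assumed model of the first round: each partition keeps at most
    // perPartition elements, or all of its elements if it holds fewer.
    static int collectedAfterFirstRound(int[] partitionSizes, int perPartition) {
        int total = 0;
        for (int size : partitionSizes) {
            total += Math.min(size, perPartition);
        }
        return total;
    }

    public static void main(String[] args) {
        int numSamples = 100;
        int inputParallelism = 4;
        int perPartition = firstRoundNumSamples(numSamples, inputParallelism);

        // Made-up skewed partitions whose sizes differ by inputParallelism.
        int[] partitionSizes = {28, 32, 32, 32};

        System.out.println(perPartition); // prints 29
        System.out.println(collectedAfterFirstRound(partitionSizes, perPartition)); // prints 115
    }
}
```

With these assumed sizes, the first round still yields at least `numSamples` elements for the second round to draw from, even though one partition holds fewer than the per-partition count.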