zhipeng93 commented on code in PR #227:
URL: https://github.com/apache/flink-ml/pull/227#discussion_r1155538452


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -288,13 +291,20 @@ public static <IN, ACC, OUT> DataStream<OUT> aggregate(
     public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
         int inputParallelism = input.getParallelism();
 
-        return input.transform(
+        // In a worst-case scenario, the data partition with the greatest number of elements has
+        // `inputParallelism` additional elements compared to the one with the fewest elements even
+        // after `rebalance` performed. Therefore, additional elements are sampled from each
+        // partition in case some partitions has insufficient elements.
+        int firstRoundNumSamples =

Review Comment:
   nit: After calling `rebalance`, the number of data points in any two 
partitions differs by at most `inputParallelism`. 
   As a result, an extra `inputParallelism` data points are sampled from each 
partition in the first round.
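
   To make the reasoning above concrete, here is a minimal, standalone sketch (not part of the PR). It assumes each upstream subtask distributes its elements round-robin independently and that, in the worst case, all of them start from the same downstream channel. The class `FirstRoundSampling` and the helpers `rebalanceWorstCase` and `firstRoundTotal` are hypothetical names used only for illustration; only the `firstRoundNumSamples` formula is taken from the diff.

```java
import java.util.Arrays;

// Hypothetical illustration class; it is not part of flink-ml.
class FirstRoundSampling {

    // Same formula as in the diff: per-partition quota for the first round.
    public static int firstRoundNumSamples(int numSamples, int inputParallelism) {
        return Math.min((numSamples / inputParallelism) + inputParallelism, numSamples);
    }

    // Assumed worst case: every upstream subtask starts its round-robin at channel 0,
    // so each one can leave channel 0 with one more element than the last channel.
    public static int[] rebalanceWorstCase(int[] upstreamSizes, int parallelism) {
        int[] partitionSizes = new int[parallelism];
        for (int upstreamSize : upstreamSizes) {
            for (int i = 0; i < upstreamSize; i++) {
                partitionSizes[i % parallelism]++;
            }
        }
        return partitionSizes;
    }

    // Number of elements that survive the first round when each partition
    // keeps at most `perPartition` elements.
    public static int firstRoundTotal(int[] partitionSizes, int perPartition) {
        int total = 0;
        for (int size : partitionSizes) {
            total += Math.min(size, perPartition);
        }
        return total;
    }

    public static void main(String[] args) {
        int parallelism = 4;
        int numSamples = 18;
        // Four upstream subtasks with five elements each (20 elements in total).
        int[] sizes = rebalanceWorstCase(new int[] {5, 5, 5, 5}, parallelism);
        int naiveQuota = numSamples / parallelism;
        int paddedQuota = firstRoundNumSamples(numSamples, parallelism);

        System.out.println("partition sizes: " + Arrays.toString(sizes));
        System.out.println("naive quota " + naiveQuota + " keeps "
                + firstRoundTotal(sizes, naiveQuota) + " elements");
        System.out.println("padded quota " + paddedQuota + " keeps "
                + firstRoundTotal(sizes, paddedQuota) + " elements");
    }
}
```

   In this example the partition sizes come out as 8, 4, 4, 4, so the largest and smallest partitions differ by exactly `parallelism`. A quota of `numSamples / parallelism` keeps only 16 elements, fewer than the 18 requested, while the padded quota keeps all 20, leaving enough for the second round to pick 18.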



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -288,13 +291,20 @@ public static <IN, ACC, OUT> DataStream<OUT> aggregate(
     public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
         int inputParallelism = input.getParallelism();
 
-        return input.transform(
+        // In a worst-case scenario, the data partition with the greatest number of elements has
+        // `inputParallelism` additional elements compared to the one with the fewest elements even
+        // after `rebalance` performed. Therefore, additional elements are sampled from each
+        // partition in case some partitions has insufficient elements.
+        int firstRoundNumSamples =
+                Math.min((numSamples / inputParallelism) + inputParallelism, numSamples);
+        return input.rebalance()
+                .transform(
                         "samplingOperator",
                         input.getType(),
-                        new SamplingOperator<>(numSamples, randomSeed))
+                        new SamplingOperator<>(firstRoundNumSamples, randomSeed))
                 .setParallelism(inputParallelism)
                 .transform(
-                        "samplingOperator",
+                        "samplingOperator-2nd-round",

Review Comment:
   nit: Let's make the sampling operator names consistent.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.