yunfengzhou-hub commented on a change in pull request #32:
URL: https://github.com/apache/flink-ml/pull/32#discussion_r755671196



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -83,43 +97,19 @@ public void before() {
         dataTable = tEnv.fromDataStream(env.fromCollection(DATA), 
schema).as("features");
     }
 
-    // Executes the graph and returns a map which maps points to clusterId.
-    private static Map<DenseVector, Integer> executeAndCollect(
-            Table output, String featureCol, String predictionCol) throws 
Exception {
-        StreamTableEnvironment tEnv =
-                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
-
-        DataStream<Tuple2<DenseVector, Integer>> stream =
-                tEnv.toDataStream(output)
-                        .map(
-                                new MapFunction<Row, Tuple2<DenseVector, 
Integer>>() {
-                                    @Override
-                                    public Tuple2<DenseVector, Integer> 
map(Row row) {
-                                        return Tuple2.of(
-                                                (DenseVector) 
row.getField(featureCol),
-                                                (Integer) 
row.getField(predictionCol));
-                                    }
-                                });
-
-        List<Tuple2<DenseVector, Integer>> pointsWithClusterId =
-                IteratorUtils.toList(stream.executeAndCollect());
-
-        Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>();
-        for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) {
-            clusterIdByPoints.put(entry.f0, entry.f1);
-        }
-        return clusterIdByPoints;
-    }
-
-    private static void verifyClusteringResult(
-            Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> 
groups) {
-        for (List<Integer> group : groups) {
-            for (int i = 1; i < group.size(); i++) {
-                assertEquals(
-                        clusterIdByPoints.get(DATA.get(group.get(0))),
-                        clusterIdByPoints.get(DATA.get(group.get(i))));
+    private static List<Set<DenseVector>> executeAndCollect(
+            Table output, String featureCol, String predictionCol) {
+        Map<Integer, Set<DenseVector>> map = new HashMap<>();
+        for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            int predict = (Integer) row.getField(predictionCol);
+            if (!map.containsKey(predict)) {
+                map.put(predict, new HashSet<>());

Review comment:
       OK. I'll fix it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to