yunfengzhou-hub commented on a change in pull request #32: URL: https://github.com/apache/flink-ml/pull/32#discussion_r755671196
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java ########## @@ -83,43 +97,19 @@ public void before() { dataTable = tEnv.fromDataStream(env.fromCollection(DATA), schema).as("features"); } - // Executes the graph and returns a map which maps points to clusterId. - private static Map<DenseVector, Integer> executeAndCollect( - Table output, String featureCol, String predictionCol) throws Exception { - StreamTableEnvironment tEnv = - (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); - - DataStream<Tuple2<DenseVector, Integer>> stream = - tEnv.toDataStream(output) - .map( - new MapFunction<Row, Tuple2<DenseVector, Integer>>() { - @Override - public Tuple2<DenseVector, Integer> map(Row row) { - return Tuple2.of( - (DenseVector) row.getField(featureCol), - (Integer) row.getField(predictionCol)); - } - }); - - List<Tuple2<DenseVector, Integer>> pointsWithClusterId = - IteratorUtils.toList(stream.executeAndCollect()); - - Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>(); - for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) { - clusterIdByPoints.put(entry.f0, entry.f1); - } - return clusterIdByPoints; - } - - private static void verifyClusteringResult( - Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> groups) { - for (List<Integer> group : groups) { - for (int i = 1; i < group.size(); i++) { - assertEquals( - clusterIdByPoints.get(DATA.get(group.get(0))), - clusterIdByPoints.get(DATA.get(group.get(i)))); + private static List<Set<DenseVector>> executeAndCollect( + Table output, String featureCol, String predictionCol) { + Map<Integer, Set<DenseVector>> map = new HashMap<>(); + for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + DenseVector vector = (DenseVector) row.getField(featureCol); + int predict = (Integer) row.getField(predictionCol); + if (!map.containsKey(predict)) { + map.put(predict, new HashSet<>()); Review 
comment: OK. I'll fix it.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org
For queries about this service, please contact Infrastructure at: us...@infra.apache.org