Github user greghogan commented on a diff in the pull request: https://github.com/apache/flink/pull/2053#discussion_r69121905 --- Diff: flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/AffinityPropagation.java --- @@ -0,0 +1,535 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.graph.library; + +import org.apache.flink.api.common.aggregators.LongSumAggregator; +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.graph.EdgeDirection; +import org.apache.flink.graph.Graph; +import org.apache.flink.graph.GraphAlgorithm; +import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.Edge; +import org.apache.flink.graph.EdgesFunction; +import org.apache.flink.graph.spargel.MessageIterator; +import org.apache.flink.graph.spargel.MessagingFunction; +import org.apache.flink.graph.spargel.ScatterGatherConfiguration; +import org.apache.flink.graph.spargel.VertexUpdateFunction; +import org.apache.flink.types.LongValue; +import org.apache.flink.types.NullValue; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + + +/** + * This is an implementation of the Binary Affinity Propagation algorithm using a scatter-gather iteration. + * Note that is not the original Affinity Propagation. + * + * The input is an undirected graph where the vertices are the points to be clustered and the edge weights are the + * similarities of these points among them. + * + * The output is a Dataset of Tuple2, where f0 is the point id and f1 is the exemplar, so the clusters will be the + * the Tuples grouped by f1 + * + * @see <a href="http://www.psi.toronto.edu/pubs2/2009/NC09%20-%20SimpleAP.pdf"> + */ + +@SuppressWarnings("serial") +public class AffinityPropagation implements GraphAlgorithm<Long,NullValue,Double,DataSet<Tuple2<Long, Long>>> { + + private static Integer maxIterations; + private static float damping; + private static float epsilon; + + /** + * Creates a new AffinityPropagation instance algorithm instance. 
+ * + * @param maxIterations The maximum number of iterations to run + * @param damping Damping factor. + * @param epsilon Epsilon factor. Do not send message to a neighbor if the new message + * has not changed more than epsilon. + */ + public AffinityPropagation(Integer maxIterations, float damping, float epsilon) { + this.maxIterations = maxIterations; + this.damping = damping; + this.epsilon = epsilon; + } + + @Override + public DataSet<Tuple2<Long, Long>> run(Graph<Long, NullValue, Double> input) throws Exception { + + // Create E and I AP vertices + DataSet<Vertex<Long, APVertexValue>> verticesWithAllInNeighbors = + input.groupReduceOnEdges(new InitAPVertex(), EdgeDirection.IN); + + List<Vertex<Long, APVertexValue>> APvertices = verticesWithAllInNeighbors.collect(); + + // Create E and I AP edges. Could this be done with some gelly functionality? + List<Edge<Long, NullValue>> APedges = new ArrayList<>(); + + for(int i = 1; i < input.numberOfVertices() + 1; i++){ + for(int j = 1; j < input.numberOfVertices() + 1; j++){ + APedges.add(new Edge<>(i * 10L, j * 10L + 1, NullValue.getInstance())); + } + } + + DataSet<Edge<Long, NullValue>> APEdgesDS = input.getContext().fromCollection(APedges); + DataSet<Vertex<Long, APVertexValue>> APVerticesDS = input.getContext().fromCollection(APvertices); + + ScatterGatherConfiguration parameters = new ScatterGatherConfiguration(); + parameters.registerAggregator("convergedAggregator", new LongSumAggregator()); + + Graph<Long, APVertexValue, NullValue> APgraph + = Graph.fromDataSet(APVerticesDS, APEdgesDS, input.getContext()); + + return APgraph.getUndirected().runScatterGatherIteration(new APVertexUpdater(input.numberOfVertices() * 2), + new APMessenger(),this.maxIterations,parameters).getVertices().filter(new FilterFunction<Vertex<Long, APVertexValue>>() { + @Override + public boolean filter(Vertex<Long, APVertexValue> vertex) throws Exception { + return vertex.getId()%2 == 0; + } + }).map(new MapFunction<Vertex<Long, 
APVertexValue>, Tuple2<Long, Long>>() { + @Override + public Tuple2<Long, Long> map(Vertex<Long, APVertexValue> value) throws Exception { + Tuple2<Long, Long> returnValue = new Tuple2<>(value.getId()/10, value.getValue().getExemplar()/10); + return returnValue; + } + }); + + } + + /** + * Foreach input point we have to create a pair of E,I vertices. Same structure is used for both vertex type, to + * diferenciate E and I vertices is used the id. Foreach input point we will create: + * + * - One E vertex with the id as the original input id * 10 + 1 + * - One I vertex with the id as the original input id * 10 + * + * This way even ids are from E type vertices and odd ids are from I vertices. + * + * It also calculates adds the weights to the I vertices. Notice that the S vertices are not created and the weights + * are added to the I vertices, simulating the S vertex. + */ + + @SuppressWarnings("serial") + private static final class InitAPVertex implements EdgesFunction<Long, Double, Vertex<Long, APVertexValue>> { + + @Override + public void iterateEdges(Iterable<Tuple2<Long, Edge<Long, Double>>> edges, + Collector<Vertex<Long, APVertexValue>> out) throws Exception { + + Vertex<Long, APVertexValue> APvertexI = new Vertex<>(); + Vertex<Long, APVertexValue> APvertexE = new Vertex<>(); + + Iterator<Tuple2<Long, Edge<Long, Double>>> itr = edges.iterator(); + Tuple2<Long, Edge<Long, Double>> edge = itr.next(); + + APvertexE.setId(edge.f0 * 10 + 1); + APvertexE.setValue(new APVertexValue()); + + APvertexI.setId(edge.f0 * 10); + APvertexI.setValue(new APVertexValue()); + APvertexI.getValue().getWeights().put(edge.f1.getSource() * 10 + 1, edge.f1.getValue()); + + APvertexE.getValue().getOldValues().put(edge.f1.getSource() * 10, 0.0); + APvertexI.getValue().getOldValues().put(edge.f1.getSource() * 10 + 1, 0.0); + + + while(itr.hasNext()){ + edge = itr.next(); + APvertexI.getValue().getWeights().put(edge.f1.getSource() * 10 + 1, edge.f1.getValue()); + + 
APvertexE.getValue().getOldValues().put(edge.f1.getSource() * 10, 0.0); + APvertexI.getValue().getOldValues().put(edge.f1.getSource() * 10 + 1, 0.0); + + } + + out.collect(APvertexE); + out.collect(APvertexI); + } + } + + /** + * Vertex updater + */ + + @SuppressWarnings("serial") + public static final class APVertexUpdater extends VertexUpdateFunction<Long, APVertexValue, APMessage> { + + private Long numOfVertex; + LongSumAggregator aggregator = new LongSumAggregator(); + + public APVertexUpdater(Long numOfVertex){ + this.numOfVertex = numOfVertex; + } + + @Override + public void preSuperstep() throws Exception { + + aggregator = getIterationAggregator("convergedAggregator"); + + } + + /** + * Main vertex update function. It calls updateIVertex, updateEVertex, computeExemplars or computeClusters + * depending on the phase of the algorithm execution + */ + + @Override + public void updateVertex(Vertex<Long, APVertexValue> vertex, + MessageIterator<APMessage> inMessages) { + + //If all vertices converged compute the Exemplars + + if(getSuperstepNumber() > 1 + && (((LongValue)getPreviousIterationAggregate("convergedAggregator")).getValue() + == numOfVertex|| getSuperstepNumber() == maxIterations-2)) { + computeExemplars(vertex, inMessages); + return; + } + + //Once the exemplars have been calculated calculate the clusters. The aggregator has a negative value assigned + //when exemplars are calculated + if(getSuperstepNumber() > 1 + && ((LongValue)getPreviousIterationAggregate("convergedAggregator")).getValue() + < 0) { + if(vertex.getValue().getExemplar() < 0){ + computeClusters(vertex, inMessages); + } + return; + } + + //Call updateIvertex or updateEvertex depending on the id + if(vertex.getId()%2 == 0){ --- End diff -- Are the I and E vertex updates not talking past each other? We had this consideration with HITS where the computation is running twice in parallel and the updates are crisscrossing.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes to enable it, or if the feature is enabled but not working, please contact the infrastructure team at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---