Github user andralungu commented on a diff in the pull request: https://github.com/apache/flink/pull/408#discussion_r25853291 --- Diff: flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java --- @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.graph.gsa; + +import org.apache.commons.lang3.Validate; +import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.functions.RichFlatJoinFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.operators.CustomUnaryOperation; +import org.apache.flink.api.java.operators.DeltaIteration; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.graph.Edge; +import org.apache.flink.graph.Vertex; +import org.apache.flink.util.Collector; + +import java.io.Serializable; + +/** + * This class represents iterative graph computations, programmed in a gather-sum-apply perspective. 
+ * + * @param <K> The type of the vertex key in the graph + * @param <VV> The type of the vertex value in the graph + * @param <EV> The type of the edge value in the graph + * @param <M> The intermediate type used by the gather, sum and apply functions + */ +public class GatherSumApplyIteration<K extends Comparable<K> & Serializable, + VV extends Serializable, EV extends Serializable, M> implements CustomUnaryOperation<Vertex<K, VV>, + Vertex<K, VV>> { + + private DataSet<Vertex<K, VV>> vertexDataSet; + private DataSet<Edge<K, EV>> edgeDataSet; + + private final GatherFunction<VV, EV, M> gather; + private final SumFunction<VV, EV, M> sum; + private final ApplyFunction<VV, EV, M> apply; + private final int maximumNumberOfIterations; + + private String name; + private int parallelism = -1; + + // ---------------------------------------------------------------------------------- + + private GatherSumApplyIteration(GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum, + ApplyFunction<VV, EV, M> apply, DataSet<Edge<K, EV>> edges, int maximumNumberOfIterations) { + + Validate.notNull(gather); + Validate.notNull(sum); + Validate.notNull(apply); + Validate.notNull(edges); + Validate.isTrue(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one."); + + this.gather = gather; + this.sum = sum; + this.apply = apply; + this.edgeDataSet = edges; + this.maximumNumberOfIterations = maximumNumberOfIterations; + } + + + /** + * Sets the name for the gather-sum-apply iteration. The name is displayed in logs and messages. + * + * @param name The name for the iteration. + */ + public void setName(String name) { + this.name = name; + } + + /** + * Gets the name from this gather-sum-apply iteration. + * + * @return The name of the iteration. + */ + public String getName() { + return name; + } + + /** + * Sets the degree of parallelism for the iteration. + * + * @param parallelism The degree of parallelism. 
+ */ + public void setParallelism(int parallelism) { + Validate.isTrue(parallelism > 0 || parallelism == -1, + "The degree of parallelism must be positive, or -1 (use default)."); + this.parallelism = parallelism; + } + + /** + * Gets the iteration's degree of parallelism. + * + * @return The iterations parallelism, or -1, if not set. + */ + public int getParallelism() { + return parallelism; + } + + // -------------------------------------------------------------------------------------------- + // Custom Operator behavior + // -------------------------------------------------------------------------------------------- + + /** + * Sets the input data set for this operator. In the case of this operator this input data set represents + * the set of vertices with their initial state. + * + * @param dataSet The input data set, which in the case of this operator represents the set of + * vertices with their initial state. + */ + @Override + public void setInput(DataSet<Vertex<K, VV>> dataSet) { + this.vertexDataSet = dataSet; + } + + /** + * Computes the results of the gather-sum-apply iteration + * + * @return The resulting DataSet + */ + @Override + public DataSet<Vertex<K, VV>> createResult() { + if (vertexDataSet == null) { + throw new IllegalStateException("The input data set has not been set."); + } + + // Prepare type information + TypeInformation<K> keyType = ((TupleTypeInfo<?>) vertexDataSet.getType()).getTypeAt(0); + TypeInformation<M> messageType = TypeExtractor.createTypeInfo(GatherFunction.class, gather.getClass(), 2, null, null); + TypeInformation<Tuple2<K, M>> innerType = new TupleTypeInfo<Tuple2<K, M>>(keyType, messageType); + TypeInformation<Vertex<K, VV>> outputType = vertexDataSet.getType(); + + // Prepare UDFs + GatherUdf<K, VV, EV, M> gatherUdf = new GatherUdf<K, VV, EV, M>(gather, innerType); + SumUdf<K, VV, EV, M> sumUdf = new SumUdf<K, VV, EV, M>(sum, innerType); + ApplyUdf<K, VV, EV, M> applyUdf = new ApplyUdf<K, VV, EV, M>(apply, outputType); 
+ + final int[] zeroKeyPos = new int[] {0}; + final DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration = + vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos); + + // Prepare triplets + DataSet<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> triplets = iteration + .getWorkset() + .join(edgeDataSet) + .where(0) + .equalTo(0) + .with(new PairJoinFunction<K, VV, EV>()) + .join(iteration.getSolutionSet()) + .where(0) + .equalTo(0) + .with(new TripletJoinFunction<K, VV, EV>()); + + // Gather, sum and apply + DataSet<Tuple2<K, M>> gatheredSet = triplets.map(gatherUdf); + DataSet<Tuple2<K, M>> summedSet = gatheredSet.groupBy(0).reduce(sumUdf); + DataSet<Vertex<K, VV>> appliedSet = summedSet + .join(iteration.getSolutionSet()) + .where(0) + .equalTo(0) + .with(applyUdf); + + return iteration.closeWith(appliedSet, appliedSet); + } + + /** + * Creates a new gather-sum-apply iteration operator for graphs + * + * @param edges The edge DataSet + * + * @param gather The gather function of the GSA iteration + * @param sum The sum function of the GSA iteration + * @param apply The apply function of the GSA iteration + * + * @param maximumNumberOfIterations The maximum number of iterations executed + * + * @param <K> The type of the vertex key in the graph + * @param <VV> The type of the vertex value in the graph + * @param <EV> The type of the edge value in the graph + * @param <M> The intermediate type used by the gather, sum and apply functions + * + * @return An in stance of the gather-sum-apply graph computation operator. 
+ */ + public static final <K extends Comparable<K> & Serializable, VV extends Serializable, EV extends Serializable, M> + GatherSumApplyIteration<K, VV, EV, M> withEdges(DataSet<Edge<K, EV>> edges, + GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum, ApplyFunction<VV, EV, M> apply, + int maximumNumberOfIterations) { + return new GatherSumApplyIteration<K, VV, EV, M>(gather, sum, apply, edges, maximumNumberOfIterations); + } + + // -------------------------------------------------------------------------------------------- + // Triplet Utils + // -------------------------------------------------------------------------------------------- + + private static final class PairJoinFunction<K extends Comparable<K> & Serializable, VV extends Serializable, + EV extends Serializable> implements FlatJoinFunction<Vertex<K, VV>, Edge<K, EV>, + Tuple3<K, Vertex<K, VV>, Edge<K, EV>>> { + + @Override + public void join(Vertex<K, VV> vertex, Edge<K, EV> edge, + Collector<Tuple3<K, Vertex<K, VV>, Edge<K, EV>>> collector) throws Exception { + collector.collect(new Tuple3<K, Vertex<K, VV>, Edge<K, EV>>(edge.getTarget(), vertex, edge)); + } + } + + private static final class TripletJoinFunction<K extends Comparable<K> & Serializable, VV extends Serializable, + EV extends Serializable> implements FlatJoinFunction<Tuple3<K, Vertex<K, VV>, Edge<K, EV>>, Vertex<K, VV>, + Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> { + @Override + public void join(Tuple3<K, Vertex<K, VV>, Edge<K, EV>> vertexEdge, Vertex<K, VV> vertex, + Collector<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> out) throws Exception { + out.collect(new Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>( + vertexEdge.f1, vertexEdge.f2, vertex + )); + } + } + + // -------------------------------------------------------------------------------------------- + // Wrapping UDFs + // -------------------------------------------------------------------------------------------- + + private static final class 
GatherUdf<K extends Comparable<K> & Serializable, VV extends Serializable, + EV extends Serializable, M> extends RichMapFunction<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>, + Tuple2<K, M>> implements ResultTypeQueryable<Tuple2<K, M>> { + + private final GatherFunction<VV, EV, M> gatherFunction; + private transient TypeInformation<Tuple2<K, M>> resultType; + + private GatherUdf(GatherFunction<VV, EV, M> gatherFunction, TypeInformation<Tuple2<K, M>> resultType) { + this.gatherFunction = gatherFunction; + this.resultType = resultType; + } + + @Override + public Tuple2<K, M> map(Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>> triplet) throws Exception { + Triplet<VV, EV> userTriplet = new Triplet<VV, EV>(triplet.f0.getValue(), + triplet.f1.getValue(), triplet.f2.getValue()); + + K key = triplet.f2.getId(); + M result = this.gatherFunction.gather(userTriplet); + return new Tuple2<K, M>(key, result); + } + + @Override + public void open(Configuration parameters) throws Exception { --- End diff -- @vasia, Perhaps we should check if the superstep number bug in vertexCentricIteration is reproducible, because from the looks of this, it should be...
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---