Hi, I'm working on large-scale graph analytics and am currently implementing a mean field inference algorithm in GraphX/Spark. I start with an arbitrary graph and keep a (sparse) probability distribution at each node, implemented as a Map[Long, Double]. At each iteration, I first update some global variables from the current estimates of the distributions using two accumulators, then gather the neighbors' probability distributions with mapReduceTriplets, and finally update each node's distribution from the gathered message plus the two accumulator values, which are broadcast to the cluster with sc.broadcast. Unfortunately, this works well only for extremely small graphs; it becomes drastically slower as the graph size and the number of iterations grow (it doesn't finish 20 iterations on a graph with 48,000 edges).
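To make the gather step concrete, here is a minimal, self-contained sketch of what I mean by "probability distributions as Map[Long, Double] merged with mapReduceTriplets". The names (Dist, sendMsg, mergeMsg, gather) are placeholders I made up for this sketch, not the identifiers from my real code, which is further below:

    import org.apache.spark.graphx._

    object MFSketch {
      // Sparse distribution kept at each vertex
      type Dist = Map[Long, Double]

      // Send each endpoint's distribution to the other endpoint
      def sendMsg(tp: EdgeTriplet[Dist, Long]): Iterator[(VertexId, Dist)] =
        Iterator((tp.dstId, tp.srcAttr), (tp.srcId, tp.dstAttr))

      // Merge two sparse distributions by summing values key by key
      def mergeMsg(a: Dist, b: Dist): Dist =
        (a.keySet ++ b.keySet)
          .map(k => k -> (a.getOrElse(k, 0.0) + b.getOrElse(k, 0.0)))
          .toMap

      // Gather the summed neighbor distributions for every vertex
      def gather(g: Graph[Dist, Long]): VertexRDD[Dist] =
        g.mapReduceTriplets(sendMsg, mergeMsg)
    }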
I suspect the problem is related to the broadcast variables, so I tried using .checkpoint() to remove them from the lineage, and I tried different StorageLevels for persistence, but without success. It seems that a lot of work is unnecessarily recomputed at each iteration, whatever I do. I also made several changes to limit the number of dependencies of each object, but that didn't change anything. Here is a sample of the code (simplified for readability, so it won't run as-is); hopefully it gives you a feel for what it is doing. Thanks!

    def run(graph: Graph[Long, Long], m: Long)(implicit sc: SparkContext) = {
      var fusionMap = Map[Long, Long]().withDefault(x => x)

      // Initial values
      val tots = Map[Long, Double]().withDefaultValue(1.0)
      var totBcst = sc.broadcast(tots)
      var fusionBcst = sc.broadcast(fusionMap)
      val mC = sc.broadcast(m)

      // Initial graph
      var g = graph.mapVertices({ case (vid, deg) => VertexProp(initialDistribution(vid), deg) })
      var newVerts = g.vertices

      // Initial messages
      var msg = g.mapReduceTriplets(MFExecutor.sendMsgMF, MFExecutor.mergeMsgMF)

      var iter = 0
      while (iter < 20) {
        // MF messages
        val oldMessages = msg
        val oldVerts = newVerts
        newVerts = newVerts.innerJoin(msg)(MFExecutor.vprogMF(mC, totBcst, fusionBcst)) //.persist(StorageLevel.MEMORY_AND_DISK)
        newVerts.checkpoint()
        newVerts.count()
        val prevG = g
        g = graph.outerJoinVertices(newVerts)({ case (vid, deg, newOpt) =>
          newOpt.getOrElse(VertexProp(Map(vid -> 1.0).withDefaultValue(0.0), deg))
        }).cache()
        //g = g.outerJoinVertices(newVerts)({ case (vid, old, newOpt) => newOpt.getOrElse(old) })

        // 1st global variable
        val fusionAcc = sc.accumulable[Map[Long, Long], (Long, Long)](fusionMap)(FusionAccumulable)
        g.triplets
          .filter(tp => testEq(fusionBcst)(tp.srcId, tp.dstId) && (spd.dotPD(tp.dstAttr.prob, tp.srcAttr.prob) > 0.9))
          .foreach(tp => fusionAcc += (tp.dstId, tp.srcId))
        fusionBcst.unpersist(blocking = false)
        fusionMap = fusionAcc.value
        fusionBcst = sc.broadcast(fusionMap)

        // 2nd global variable
        val totAcc = sc.accumulator[Map[Long, Double]](Map[Long, Double]().withDefaultValue(0.0))(TotAccumulable)
        newVerts.foreach({ case (vid, vprop) =>
          totAcc += vprop.prob.mapValues(p => p * vprop.deg).withDefaultValue(0.0)
        })
        totBcst.unpersist(blocking = false)
        totBcst = sc.broadcast(totAcc.value)

        // New MF messages
        msg = g.mapReduceTriplets(MFExecutor.sendMsgMF, MFExecutor.mergeMsgMF)

        // Unpersist old data
        oldMessages.unpersist(blocking = false)
        oldVerts.unpersist(blocking = false)
        prevG.unpersistVertices(blocking = false)

        iter = iter + 1
      }
    }
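For completeness, this is the generic lineage-truncation pattern I have been trying to apply (a simplified sketch on a dummy RDD, not my actual code): persist first, checkpoint, then force materialization so the checkpoint actually takes effect before the next iteration builds on it. The checkpoint directory path and the every-5-iterations cadence are just placeholders for the sketch:

    import org.apache.spark.SparkContext
    import org.apache.spark.storage.StorageLevel

    object CheckpointSketch {
      def run(sc: SparkContext): Unit = {
        sc.setCheckpointDir("/tmp/spark-checkpoints") // placeholder path

        var rdd = sc.parallelize(1 to 1000).map(_.toLong)
        for (iter <- 0 until 20) {
          rdd = rdd.map(_ + 1).persist(StorageLevel.MEMORY_AND_DISK)
          if (iter % 5 == 0) {
            rdd.checkpoint() // truncate the lineage every few iterations
          }
          rdd.count() // materialize so the checkpoint actually happens
        }
      }
    }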