Hi,

I'm working on large-scale graph analytics and am currently implementing a
mean field inference algorithm in GraphX/Spark. I start from an arbitrary
graph and keep a (sparse) probability distribution at each node, implemented
as a Map[Long, Double]. At each iteration, I first update some global
variables from the current distribution estimates using two accumulators,
then gather the neighbors' probability distributions with mapReduceTriplets,
and finally update each distribution from the gathered message plus the two
accumulator values, which are broadcast to the cluster with sc.broadcast.
Unfortunately, while this works well for very small graphs, it slows down
dramatically as the graph size and the number of iterations grow (it does
not finish 20 iterations on a graph with 48,000 edges).
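
In outline, each iteration boils down to the following (a condensed sketch
of the full code further down; sendMsg, mergeMsg and vprog stand for my
MFExecutor methods):

    // condensed outline of one iteration (see full code below)
    totBcst = sc.broadcast(totals)                    // global variables gathered via accumulators
    fusionBcst = sc.broadcast(fusions)
    msg = g.mapReduceTriplets(sendMsg, mergeMsg)      // gather neighbors' distributions
    newVerts = newVerts.innerJoin(msg)(vprog(mC, totBcst, fusionBcst))  // update distributions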

I suspect the problem is related to the broadcast variables, so I tried
using .checkpoint() to remove the broadcast variables from the lineage, and
different StorageLevels for persistence, but without success. It looks like
a lot of work is recomputed unnecessarily at each iteration, whatever I try.
I also made several changes to limit the number of dependencies of each
object, but that didn't change anything.
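
Concretely, the checkpointing attempt looks like this (the checkpoint
directory is only an example path):

    // what I tried to cut the lineage at each iteration
    sc.setCheckpointDir("hdfs:///tmp/checkpoints")     // example path
    newVerts.persist(StorageLevel.MEMORY_AND_DISK)
    newVerts.checkpoint()
    newVerts.count()   // force materialization so the checkpoint actually happens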

Here is a sample of the code (simplified to be readable, so it does not run
as-is); hopefully it gives a feel for what the algorithm is doing. Thanks!

def run(graph: Graph[Long, Long], m: Long)(implicit sc: SparkContext) = {
    var fusionMap = Map[Long, Long]().withDefault(x => x)
    // Initial values
    val tots = Map[Long, Double]().withDefaultValue(1.0)
    var totBcst = sc.broadcast(tots)
    var fusionBcst = sc.broadcast(fusionMap)
    val mC = sc.broadcast(m)
    // Initial graph
    var g = graph.mapVertices({ case (vid, deg) => VertexProp(initialDistribution(vid), deg) })
    var newVerts = g.vertices
    // Initial messages
    var msg = g.mapReduceTriplets(MFExecutor.sendMsgMF, MFExecutor.mergeMsgMF)
    var iter = 0

    while (iter < 20) {
      // MF Messages
      val oldMessages = msg
      val oldVerts = newVerts
      newVerts = newVerts.innerJoin(msg)(MFExecutor.vprogMF(mC, totBcst, fusionBcst)) //.persist(StorageLevel.MEMORY_AND_DISK)
      newVerts.checkpoint()
      newVerts.count()
      val prevG = g
      g = graph.outerJoinVertices(newVerts)({ case (vid, deg, newOpt) =>
        newOpt.getOrElse(VertexProp(Map(vid -> 1.0).withDefaultValue(0.0), deg))
      }).cache()
      //g = g.outerJoinVertices(newVerts)({ case (vid, old, newOpt) => newOpt.getOrElse(old) })

      // 1st global variable
      val fusionAcc = sc.accumulable[Map[Long, Long], (Long, Long)](fusionMap)(FusionAccumulable)
      g.triplets
        .filter(tp => testEq(fusionBcst)(tp.srcId, tp.dstId) &&
          (spd.dotPD(tp.dstAttr.prob, tp.srcAttr.prob) > 0.9))
        .foreach(tp => fusionAcc += (tp.dstId, tp.srcId))
      fusionBcst.unpersist(blocking = false)
      fusionMap = fusionAcc.value
      fusionBcst = sc.broadcast(fusionMap)

      // 2nd global variable
      val totAcc = sc.accumulator[Map[Long, Double]](Map[Long, Double]().withDefaultValue(0.0))(TotAccumulable)
      newVerts.foreach({ case (vid, vprop) =>
        totAcc += vprop.prob.mapValues(p => p * vprop.deg).withDefaultValue(0.0)
      })
      totBcst.unpersist(blocking = false)
      totBcst = sc.broadcast(totAcc.value)

      // New MF messages
      msg = g.mapReduceTriplets(MFExecutor.sendMsgMF, MFExecutor.mergeMsgMF)

      // Unpersist options
      oldMessages.unpersist(blocking = false)
      oldVerts.unpersist(blocking=false)
      prevG.unpersistVertices(blocking=false)

      iter = iter + 1
    } 
}
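
In case it helps, the two accumulator params used above (FusionAccumulable
and TotAccumulable) look roughly like this (simplified sketches rather than
the exact code):

  // simplified sketch: accumulates (dstId, srcId) pairs into a fusion map
  object FusionAccumulable extends AccumulableParam[Map[Long, Long], (Long, Long)] {
    def addAccumulator(acc: Map[Long, Long], pair: (Long, Long)): Map[Long, Long] = acc + pair
    def addInPlace(m1: Map[Long, Long], m2: Map[Long, Long]): Map[Long, Long] = m1 ++ m2
    def zero(initial: Map[Long, Long]): Map[Long, Long] = initial
  }

  // simplified sketch: sums per-key totals across partitions
  object TotAccumulable extends AccumulatorParam[Map[Long, Double]] {
    def addInPlace(m1: Map[Long, Double], m2: Map[Long, Double]): Map[Long, Double] =
      (m1.keySet ++ m2.keySet).map(k => k -> (m1.getOrElse(k, 0.0) + m2.getOrElse(k, 0.0))).toMap
    def zero(initial: Map[Long, Double]): Map[Long, Double] = Map[Long, Double]().withDefaultValue(0.0)
  }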


