@@ -4,21 +4,51 @@ import spark._
 import spark.SparkContext._
 
 import scala.collection.mutable.ArrayBuffer
+import storage.StorageLevel
 
 object Bagel extends Logging {
-  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest,
-          C: Manifest, A: Manifest](
+  val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
+
+  /**
+   * Runs a Bagel program.
+   * @param sc [[spark.SparkContext]] to use for the program.
+   * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be
+   *                 the vertex id.
+   * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
+   *                 empty array, i.e. sc.parallelize(Array[(K, M)]()).
+   * @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
+   *                 message before sending (which often involves network I/O).
+   * @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
+   *                   and provides the result to each vertex in the next superstep.
+   * @param partitioner [[spark.Partitioner]] partitions values by key
+   * @param numPartitions number of partitions across which to split the graph.
+   *                      Default is the default parallelism of the SparkContext
+   * @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
+   *                     Defaults to caching in memory and on disk (MEMORY_AND_DISK).
+   * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
+   *                optional Aggregator and the current superstep,
+   *                and returns a set of (Vertex, outgoing Messages) pairs
+   * @tparam K key
+   * @tparam V vertex type
+   * @tparam M message type
+   * @tparam C combiner
+   * @tparam A aggregator
+   * @return an RDD of (K, V) pairs representing the graph after completion of the program
+   */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest,
+          C: Manifest, A: Manifest](
     sc: SparkContext,
     vertices: RDD[(K, V)],
     messages: RDD[(K, M)],
     combiner: Combiner[M, C],
     aggregator: Option[Aggregator[V, A]],
     partitioner: Partitioner,
-    numSplits: Int
+    numPartitions: Int,
+    storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
   )(
     compute: (V, Option[C], Option[A], Int) => (V, Array[M])
   ): RDD[(K, V)] = {
-    val splits = if (numSplits != 0) numSplits else sc.defaultParallelism
+    val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism
 
     var superstep = 0
     var verts = vertices
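For orientation, here is a minimal sketch (not taken from this commit) of how the extended signature might be called. The `TestVertex`, `TestMessage`, and `SumCombiner` types below are hypothetical stand-ins for user code:

```scala
import spark.{SparkContext, HashPartitioner}
import spark.bagel.{Bagel, Vertex, Message, Combiner}
import spark.storage.StorageLevel

// Hypothetical vertex/message implementations (not part of this commit).
class TestVertex(val active: Boolean, val value: Double) extends Vertex with Serializable
class TestMessage(val targetId: String, val value: Double) extends Message[String] with Serializable

// Sums all messages bound for a vertex into a single Double before delivery.
object SumCombiner extends Combiner[TestMessage, Double] with Serializable {
  def createCombiner(msg: TestMessage): Double = msg.value
  def mergeMsg(combined: Double, msg: TestMessage): Double = combined + msg.value
  def mergeCombiners(a: Double, b: Double): Double = a + b
}

val sc    = new SparkContext("local", "BagelExample")
val verts = sc.parallelize(Seq(("a", new TestVertex(true, 1.0))))
val msgs  = sc.parallelize(Array[(String, TestMessage)]())

// Spill intermediate superstep RDDs to disk instead of using the default level.
val result = Bagel.run[String, TestVertex, TestMessage, Double, Nothing](
  sc, verts, msgs, SumCombiner, None, new HashPartitioner(2), 2,
  StorageLevel.DISK_ONLY) {
  (self, combined, aggregated, superstep) =>
    val newValue = combined.getOrElse(self.value)
    // Halt after ten supersteps by marking the vertex inactive.
    (new TestVertex(superstep < 10, newValue), Array[TestMessage]())
}
```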
@@ -32,8 +62,9 @@ object Bagel extends Logging {
       val combinedMsgs = msgs.combineByKey(
         combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
       val grouped = combinedMsgs.groupWith(verts)
+      val superstep_ = superstep  // Create a read-only copy of superstep for capture in closure
       val (processed, numMsgs, numActiveVerts) =
-        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
 
       val timeTaken = System.currentTimeMillis - startTime
       logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
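As an aside, a small sketch (not from this commit) of the capture behavior the new comment refers to: Scala closures capture a `var` by reference, so a closure that is serialized and run later can observe a value mutated by a subsequent loop iteration, whereas capturing a local `val` freezes the value:

```scala
var step = 0
val byRef = () => step    // reads step's current value when invoked
val step_ = step
val byVal = () => step_   // always sees the value step had at capture time
step += 1
byRef()  // => 1
byVal()  // => 0
```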
@@ -50,57 +81,103 @@ object Bagel extends Logging {
     verts
   }
 
-  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest,
-          C: Manifest](
+  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C: Manifest](
     sc: SparkContext,
     vertices: RDD[(K, V)],
     messages: RDD[(K, M)],
     combiner: Combiner[M, C],
     partitioner: Partitioner,
-    numSplits: Int
+    numPartitions: Int
+  )(
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): RDD[(K, V)] = run(sc, vertices, messages, combiner, partitioner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
+
+  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C: Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    partitioner: Partitioner,
+    numPartitions: Int,
+    storageLevel: StorageLevel
   )(
     compute: (V, Option[C], Int) => (V, Array[M])
   ): RDD[(K, V)] = {
     run[K, V, M, C, Nothing](
-      sc, vertices, messages, combiner, None, partitioner, numSplits)(
+      sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)(
       addAggregatorArg[K, V, M, C](compute))
   }
 
-  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest,
-          C: Manifest](
+  /**
+   * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]]
+   * and default storage level
+   */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C: Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    numPartitions: Int
+  )(
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
+
+  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]] */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C: Manifest](
     sc: SparkContext,
     vertices: RDD[(K, V)],
     messages: RDD[(K, M)],
     combiner: Combiner[M, C],
-    numSplits: Int
+    numPartitions: Int,
+    storageLevel: StorageLevel
   )(
     compute: (V, Option[C], Int) => (V, Array[M])
   ): RDD[(K, V)] = {
-    val part = new HashPartitioner(numSplits)
+    val part = new HashPartitioner(numPartitions)
     run[K, V, M, C, Nothing](
-      sc, vertices, messages, combiner, None, part, numSplits)(
+      sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)(
       addAggregatorArg[K, V, M, C](compute))
   }
 
-  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest](
+  /**
+   * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]],
+   * [[spark.bagel.DefaultCombiner]] and the default storage level
+   */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest](
     sc: SparkContext,
     vertices: RDD[(K, V)],
     messages: RDD[(K, M)],
-    numSplits: Int
+    numPartitions: Int
   )(
     compute: (V, Option[Array[M]], Int) => (V, Array[M])
-  ): RDD[(K, V)] = {
-    val part = new HashPartitioner(numSplits)
+  ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
+
+  /**
+   * Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]]
+   * and [[spark.bagel.DefaultCombiner]]
+   */
+  def run[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    numPartitions: Int,
+    storageLevel: StorageLevel
+  )(
+    compute: (V, Option[Array[M]], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    val part = new HashPartitioner(numPartitions)
     run[K, V, M, Array[M], Nothing](
-      sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)(
+      sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)(
       addAggregatorArg[K, V, M, Array[M]](compute))
   }
 
   /**
    * Aggregates the given vertices using the given aggregator, if it
    * is specified.
    */
-  private def agg[K, V <: Vertex, A: Manifest](
+  private def agg[K, V <: Vertex, A: Manifest](
     verts: RDD[(K, V)],
     aggregator: Option[Aggregator[V, A]]
   ): Option[A] = aggregator match {
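Continuing the earlier sketch, the simplest of the new overloads (no combiner, no aggregator, hash partitioning) might be invoked as below; `sc`, `verts`, `msgs`, `TestVertex`, and `TestMessage` are the hypothetical values defined in the previous example:

```scala
// Messages arrive uncombined as Option[Array[TestMessage]];
// passing MEMORY_ONLY reproduces the old `.cache` behavior exactly.
val out = Bagel.run[String, TestVertex, TestMessage](
  sc, verts, msgs, 2, StorageLevel.MEMORY_ONLY) {
  (self, msgsOpt, superstep) =>
    val incoming = msgsOpt.getOrElse(Array[TestMessage]())
    (new TestVertex(superstep < 5, self.value + incoming.size), Array[TestMessage]())
}
```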
@@ -116,10 +193,11 @@ object Bagel extends Logging {
    * function. Returns the processed RDD, the number of messages
    * created, and the number of active vertices.
    */
-  private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
+  private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
     sc: SparkContext,
     grouped: RDD[(K, (Seq[C], Seq[V]))],
-    compute: (V, Option[C]) => (V, Array[M])
+    compute: (V, Option[C]) => (V, Array[M]),
+    storageLevel: StorageLevel
   ): (RDD[(K, (V, Array[M]))], Int, Int) = {
     var numMsgs = sc.accumulator(0)
     var numActiveVerts = sc.accumulator(0)
@@ -137,7 +215,7 @@ object Bagel extends Logging {
           numActiveVerts += 1
 
         Some((newVert, newMsgs))
-    }.cache
+    }.persist(storageLevel)
 
     // Force evaluation of processed RDD for accurate performance measurements
     processed.foreach(x => {})
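For reference, `cache` in Spark of this era is just shorthand for persisting at MEMORY_ONLY, so the change above generalizes the old behavior rather than altering it. Roughly (paraphrased, not part of this commit):

```scala
// In spark.RDD (paraphrased): cache() persists at the MEMORY_ONLY level.
def cache(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
```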
@@ -149,9 +227,7 @@ object Bagel extends Logging {
    * Converts a compute function that doesn't take an aggregator to
    * one that does, so it can be passed to Bagel.run.
    */
-  private def addAggregatorArg[
-    K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C
-  ](
+  private def addAggregatorArg[K: Manifest, V <: Vertex: Manifest, M <: Message[K]: Manifest, C](
     compute: (V, Option[C], Int) => (V, Array[M])
   ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
     (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
@@ -170,7 +246,8 @@ trait Aggregator[V, A] {
   def mergeAggregators(a: A, b: A): A
 }
 
-class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
+/** Default combiner that simply appends messages together (i.e. performs no aggregation) */
+class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
   def createCombiner(msg: M): Array[M] =
     Array(msg)
   def mergeMsg(combiner: Array[M], msg: M): Array[M] =