Skip to content

Commit

Permalink
[SPARK-3936] Add aggregateMessages, which supersedes mapReduceTriplets
Browse files Browse the repository at this point in the history
aggregateMessages enables neighborhood computation similarly to mapReduceTriplets, but it introduces two API improvements:

1. Messages are sent using an imperative interface based on EdgeContext rather than by returning an iterator of messages.

2. Rather than attempting bytecode inspection, the required triplet fields must be explicitly specified by the user by passing a TripletFields object. This fixes SPARK-3936.

Additionally, this PR includes the following optimizations for aggregateMessages and EdgePartition:

1. EdgePartition now stores local vertex ids instead of global ids. This avoids hash lookups when looking up vertex attributes and aggregating messages.

2. Internal iterators in aggregateMessages are inlined into a while loop.

In total, these optimizations were tested to provide a 37% speedup on PageRank (uk-2007-05 graph, 10 iterations, 16 r3.2xlarge machines, sped up from 513 s to 322 s).

Subsumes apache#2815. Also fixes SPARK-4173.

Author: Ankur Dave <ankurdave@gmail.com>

Closes apache#3100 from ankurdave/aggregateMessages and squashes the following commits:

f5b65d0 [Ankur Dave] Address @rxin comments on apache#3054 and apache#3100
1e80aca [Ankur Dave] Add aggregateMessages, which supersedes mapReduceTriplets
194a2df [Ankur Dave] Test triplet iterator in EdgePartition serialization test
e0f8ecc [Ankur Dave] Take activeSet in ExistingEdgePartitionBuilder
c85076d [Ankur Dave] Readability improvements
b567be2 [Ankur Dave] iter.foreach -> while loop
4a566dc [Ankur Dave] Optimizations for mapReduceTriplets and EdgePartition
  • Loading branch information
ankurdave authored and rxin committed Nov 12, 2014
1 parent 2ef016b commit faeb41d
Show file tree
Hide file tree
Showing 15 changed files with 766 additions and 376 deletions.
51 changes: 51 additions & 0 deletions graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.graphx

/**
* Represents an edge along with its neighboring vertices and allows sending messages along the
* edge. Used in [[Graph#aggregateMessages]].
*/
abstract class EdgeContext[VD, ED, A] {
/** The vertex id of the edge's source vertex. */
def srcId: VertexId
/** The vertex id of the edge's destination vertex. */
def dstId: VertexId
/** The vertex attribute of the edge's source vertex. */
def srcAttr: VD
/** The vertex attribute of the edge's destination vertex. */
def dstAttr: VD
/** The attribute associated with the edge. */
def attr: ED

/** Sends a message to the source vertex. */
def sendToSrc(msg: A): Unit
/** Sends a message to the destination vertex. */
def sendToDst(msg: A): Unit

/** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */
def toEdgeTriplet: EdgeTriplet[VD, ED] = {
val et = new EdgeTriplet[VD, ED]
et.srcId = srcId
et.srcAttr = srcAttr
et.dstId = dstId
et.dstAttr = dstAttr
et.attr = attr
et
}
}
137 changes: 124 additions & 13 deletions graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,39 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*
*/
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
mapTriplets((pid, iter) => iter.map(map))
def mapTriplets[ED2: ClassTag](
map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
}

/**
* Transforms each edge attribute using the map function, passing it the adjacent vertex
* attributes as well. If adjacent vertex values are not required,
* consider using `mapEdges` instead.
*
* @note This does not change the structure of the
* graph or modify the values of this graph. As a consequence
* the underlying index structures can be reused.
*
* @param map the function from an edge object to a new edge value.
* @param tripletFields which fields should be included in the edge triplet passed to the map
* function. If not all fields are needed, specifying this can improve performance.
*
* @tparam ED2 the new edge data type
*
* @example This function might be used to initialize edge
* attributes based on the attributes associated with each vertex.
* {{{
* val rawGraph: Graph[Int, Int] = someLoadFunction()
* val graph = rawGraph.mapTriplets[Int]( edge =>
* edge.src.data - edge.dst.data)
* }}}
*
*/
def mapTriplets[ED2: ClassTag](
map: EdgeTriplet[VD, ED] => ED2,
tripletFields: TripletFields): Graph[VD, ED2] = {
mapTriplets((pid, iter) => iter.map(map), tripletFields)
}

/**
Expand All @@ -223,12 +254,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* the underlying index structures can be reused.
*
* @param map the iterator transform
* @param tripletFields which fields should be included in the edge triplet passed to the map
* function. If not all fields are needed, specifying this can improve performance.
*
* @tparam ED2 the new edge data type
*
*/
def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
: Graph[VD, ED2]
def mapTriplets[ED2: ClassTag](
map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
tripletFields: TripletFields): Graph[VD, ED2]

/**
* Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
Expand Down Expand Up @@ -287,6 +321,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
* the map phase destined to each vertex.
*
* This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead.
*
* @tparam A the type of "message" to be sent to each vertex
*
* @param mapFunc the user defined map function which returns 0 or
Expand All @@ -296,13 +332,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* be commutative and associative and is used to combine the output
* of the map phase
*
* @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to
* consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on
* edges with destination in the active set. If the direction is `Out`,
* `mapFunc` will only be run on edges originating from vertices in the active set. If the
* direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
* . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
* active set. The active set must have the same index as the graph's vertices.
* @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
* desired. This is done by specifying a set of "active" vertices and an edge direction. The
* `sendMsg` function will then run only on edges connected to active vertices by edges in the
* specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
* destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
* originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
* run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
* will be run on edges with *both* vertices in the active set. The active set must have the
* same index as the graph's vertices.
*
* @example We can use this function to compute the in-degree of each
* vertex
Expand All @@ -319,15 +357,88 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* predicate or implement PageRank.
*
*/
@deprecated("use aggregateMessages", "1.2.0")
def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
: VertexRDD[A]

/**
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
* input table should contain at most one entry for each vertex. If no entry in `other` is
* Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
* `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
* sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
* destined to the same vertex.
*
* @tparam A the type of message to be sent to each vertex
*
* @param sendMsg runs on each edge, sending messages to neighboring vertices using the
* [[EdgeContext]].
* @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
* combiner should be commutative and associative.
* @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
* `sendMsg` function. If not all fields are needed, specifying this can improve performance.
*
* @example We can use this function to compute the in-degree of each
* vertex
* {{{
* val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
* val inDeg: RDD[(VertexId, Int)] =
* aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
* }}}
*
* @note By expressing computation at the edge level we achieve
* maximum parallelism. This is one of the core functions in the
* Graph API in that enables neighborhood level computation. For
* example this function can be used to count neighbors satisfying a
* predicate or implement PageRank.
*
*/
def aggregateMessages[A: ClassTag](
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A,
tripletFields: TripletFields = TripletFields.All)
: VertexRDD[A] = {
aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
}

/**
* Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
* `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
* sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
* destined to the same vertex.
*
* This variant can take an active set to restrict the computation and is intended for internal
* use only.
*
* @tparam A the type of message to be sent to each vertex
*
* @param sendMsg runs on each edge, sending messages to neighboring vertices using the
* [[EdgeContext]].
* @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
* combiner should be commutative and associative.
* @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
* `sendMsg` function. If not all fields are needed, specifying this can improve performance.
* @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
* desired. This is done by specifying a set of "active" vertices and an edge direction. The
* `sendMsg` function will then run on only edges connected to active vertices by edges in the
* specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
* destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
* originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
* run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
* will be run on edges with *both* vertices in the active set. The active set must have the
* same index as the graph's vertices.
*/
private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag](
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A,
tripletFields: TripletFields,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)])
: VertexRDD[A]

/**
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
* The input table should contain at most one entry for each vertex. If no entry in `other` is
* provided for a particular vertex in the graph, the map function receives `None`.
*
* @tparam U the type of entry in the table of updates
Expand Down
85 changes: 46 additions & 39 deletions graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
*/
private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
if (edgeDirection == EdgeDirection.In) {
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None)
} else { // EdgeDirection.Either
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _,
TripletFields.None)
}
}

Expand All @@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
val nbrs =
if (edgeDirection == EdgeDirection.Either) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _
)
graph.aggregateMessages[Array[VertexId]](
ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
_ ++ _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
reduceFunc = _ ++ _)
graph.aggregateMessages[Array[VertexId]](
ctx => ctx.sendToSrc(Array(ctx.dstId)),
_ ++ _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.In) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _)
graph.aggregateMessages[Array[VertexId]](
ctx => ctx.sendToDst(Array(ctx.srcId)),
_ ++ _, TripletFields.None)
} else {
throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
"direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
Expand All @@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* @return the vertex set of neighboring vertex attributes for each vertex
*/
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
edge => {
val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
edgeDirection match {
case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
case EdgeDirection.In => Iterator(msgToDst)
case EdgeDirection.Out => Iterator(msgToSrc)
case EdgeDirection.Both =>
throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
}
},
(a, b) => a ++ b)

graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
val nbrs = edgeDirection match {
case EdgeDirection.Either =>
graph.aggregateMessages[Array[(VertexId,VD)]](
ctx => {
ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
},
(a, b) => a ++ b, TripletFields.SrcDstOnly)
case EdgeDirection.In =>
graph.aggregateMessages[Array[(VertexId,VD)]](
ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
(a, b) => a ++ b, TripletFields.SrcOnly)
case EdgeDirection.Out =>
graph.aggregateMessages[Array[(VertexId,VD)]](
ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
(a, b) => a ++ b, TripletFields.DstOnly)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
}
graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
}
} // end of collectNeighbor
Expand All @@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
edgeDirection match {
case EdgeDirection.Either =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
graph.aggregateMessages[Array[Edge[ED]]](
ctx => {
ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
},
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.In =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
graph.aggregateMessages[Array[Edge[ED]]](
ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Out =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
graph.aggregateMessages[Array[Edge[ED]]](
ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
Expand Down
Loading

0 comments on commit faeb41d

Please sign in to comment.