
Commit 3ade8be

Add clustered index on edges by source vertex
This allows an efficient edge scan in mapReduceTriplets when many source vertices are inactive: the scan method switches from a full edge scan to a clustered index scan when fewer than 80% of a partition's distinct source vertices are active.
1 parent 0476c84 commit 3ade8be

7 files changed, +167 −65 lines changed


graph/src/main/scala/org/apache/spark/graph/Edge.scala

Lines changed: 7 additions & 0 deletions

@@ -41,3 +41,10 @@ case class Edge[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED]
   def relativeDirection(vid: Vid): EdgeDirection =
     if (vid == srcId) EdgeDirection.Out else { assert(vid == dstId); EdgeDirection.In }
 }
+
+object Edge {
+  def lexicographicOrdering[ED] = new Ordering[Edge[ED]] {
+    override def compare(a: Edge[ED], b: Edge[ED]): Int =
+      Ordering[(Vid, Vid)].compare((a.srcId, a.dstId), (b.srcId, b.dstId))
+  }
+}
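
The new ordering compares edges by (srcId, dstId), so a sort with it leaves all edges that share a source id in one contiguous run — the invariant the clustered index relies on. A minimal standalone sketch (not part of the commit; Vid is simplified to Long and a local Edge stand-in is used):

import scala.util.Sorting

object LexOrderingSketch extends App {
  type Vid = Long
  case class Edge[ED](srcId: Vid, dstId: Vid, attr: ED)

  // Same comparison as the commit: order by (srcId, dstId).
  def lexicographicOrdering[ED] = new Ordering[Edge[ED]] {
    override def compare(a: Edge[ED], b: Edge[ED]): Int =
      Ordering[(Vid, Vid)].compare((a.srcId, a.dstId), (b.srcId, b.dstId))
  }

  val edges = Array(Edge(2L, 0L, "c"), Edge(0L, 1L, "b"), Edge(0L, 0L, "a"))
  Sorting.quickSort(edges)(lexicographicOrdering)
  println(edges.mkString(", "))
  // Edge(0,0,a), Edge(0,1,b), Edge(2,0,c) -- same-source edges are now adjacent
}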

graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala

Lines changed: 69 additions & 30 deletions

@@ -1,29 +1,36 @@
 package org.apache.spark.graph.impl
 
 import org.apache.spark.graph._
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 
 /**
- * A collection of edges stored in 3 large columnar arrays (src, dst, attribute).
+ * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are
+ * clustered by src.
  *
  * @param srcIds the source vertex id of each edge
 * @param dstIds the destination vertex id of each edge
 * @param data the attribute associated with each edge
+ * @param index a clustered index on source vertex id
 * @tparam ED the edge attribute type.
 */
 class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassManifest](
     val srcIds: Array[Vid],
     val dstIds: Array[Vid],
-    val data: Array[ED]) {
+    val data: Array[ED],
+    val index: PrimitiveKeyOpenHashMap[Vid, Int]) {
 
   /**
    * Reverse all the edges in this partition.
    *
-   * @note No new data structures are created.
-   *
    * @return a new edge partition with all edges reversed.
    */
-  def reverse: EdgePartition[ED] = new EdgePartition(dstIds, srcIds, data)
+  def reverse: EdgePartition[ED] = {
+    val builder = new EdgePartitionBuilder(size)
+    for (e <- iterator) {
+      builder.add(e.dstId, e.srcId, e.attr)
+    }
+    builder.toEdgePartition
+  }
 
   /**
    * Construct a new edge partition by applying the function f to all
@@ -46,25 +53,16 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
       newData(i) = f(edge)
       i += 1
     }
-    new EdgePartition(srcIds, dstIds, newData)
+    new EdgePartition(srcIds, dstIds, newData, index)
   }
 
   /**
    * Apply the function f to all edges in this partition.
    *
    * @param f an external state mutating user defined function.
    */
-  def foreach(f: Edge[ED] => Unit) {
-    val edge = new Edge[ED]
-    val size = data.size
-    var i = 0
-    while (i < size) {
-      edge.srcId = srcIds(i)
-      edge.dstId = dstIds(i)
-      edge.attr = data(i)
-      f(edge)
-      i += 1
-    }
+  def foreach(f: Edge[ED] => Unit) {
+    iterator.foreach(f)
   }
 
   /**
@@ -75,21 +73,29 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
    * @return a new edge partition without duplicate edges
    */
   def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED] = {
-    // Aggregate all matching edges in a hashmap
-    val agg = new OpenHashMap[(Vid,Vid), ED]
-    foreach { e => agg.setMerge((e.srcId, e.dstId), e.attr, merge) }
-    // Populate new srcId, dstId, and data, arrays
-    val newSrcIds = new Array[Vid](agg.size)
-    val newDstIds = new Array[Vid](agg.size)
-    val newData = new Array[ED](agg.size)
+    val builder = new EdgePartitionBuilder[ED]
+    var firstIter: Boolean = true
+    var currSrcId: Vid = nullValue[Vid]
+    var currDstId: Vid = nullValue[Vid]
+    var currAttr: ED = nullValue[ED]
     var i = 0
-    agg.foreach { kv =>
-      newSrcIds(i) = kv._1._1
-      newDstIds(i) = kv._1._2
-      newData(i) = kv._2
+    while (i < size) {
+      if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+        currAttr = merge(currAttr, data(i))
+      } else {
+        if (i > 0) {
+          builder.add(currSrcId, currDstId, currAttr)
+        }
+        currSrcId = srcIds(i)
+        currDstId = dstIds(i)
+        currAttr = data(i)
+      }
       i += 1
     }
-    new EdgePartition(newSrcIds, newDstIds, newData)
+    if (size > 0) {
+      builder.add(currSrcId, currDstId, currAttr)
+    }
+    builder.toEdgePartition
   }
 
   /**
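
The rewritten groupEdges above no longer needs a hash map: because the arrays are clustered by source and sorted by destination within each cluster, duplicate edges are always adjacent, so a single linear pass that merges runs suffices. A standalone sketch of that run-merging idea over plain (src, dst, attr) tuples (illustration only, not part of the commit):

object GroupEdgesSketch extends App {
  // Already sorted by (src, dst), as EdgePartition now guarantees.
  val edges = List((0L, 1L, 1), (0L, 1L, 2), (0L, 2L, 5), (3L, 1L, 7), (3L, 1L, 8))

  def groupEdges(sorted: List[(Long, Long, Int)], merge: (Int, Int) => Int) =
    sorted.foldLeft(List.empty[(Long, Long, Int)]) {
      // Same (src, dst) as the previous edge: merge the attributes.
      case ((src, dst, attr) :: rest, (s, d, a)) if src == s && dst == d =>
        (src, dst, merge(attr, a)) :: rest
      // Otherwise a new run starts.
      case (acc, e) => e :: acc
    }.reverse

  println(groupEdges(edges, _ + _))  // List((0,1,3), (0,2,5), (3,1,15))
}
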
@@ -99,6 +105,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
    */
   def size: Int = srcIds.size
 
+  /** The number of unique source vertices in the partition. */
+  def indexSize: Int = index.size
+
   /**
    * Get an iterator over the edges in this partition.
    *
@@ -118,4 +127,34 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
       edge
     }
   }
+
+  /**
+   * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
+   * iterator is generated using an index scan, so it is efficient at skipping edges that don't
+   * match srcIdPred.
+   */
+  def indexIterator(srcIdPred: Vid => Boolean): Iterator[Edge[ED]] =
+    index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
+
+  /**
+   * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
+   * cluster must start at position `index`.
+   */
+  private def clusterIterator(srcId: Vid, index: Int) = new Iterator[Edge[ED]] {
+    private[this] val edge = new Edge[ED]
+    private[this] var pos = index
+
+    override def hasNext: Boolean = {
+      pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+    }
+
+    override def next(): Edge[ED] = {
+      assert(srcIds(pos) == srcId)
+      edge.srcId = srcIds(pos)
+      edge.dstId = dstIds(pos)
+      edge.attr = data(pos)
+      pos += 1
+      edge
+    }
+  }
 }
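
Taken together: the index maps each distinct source id to the position where its cluster starts, and clusterIterator walks forward from there while the source id still matches. A standalone sketch of the scan (a plain Map stands in for PrimitiveKeyOpenHashMap, pairs for Edge; not part of the commit):

object IndexScanSketch extends App {
  // Columnar storage, clustered by source id, as EdgePartitionBuilder produces.
  val srcIds = Array(0L, 0L, 1L, 5L, 5L, 5L)
  val dstIds = Array(1L, 2L, 0L, 3L, 4L, 6L)
  // Clustered index: source id -> position where its cluster starts.
  val index = Map(0L -> 0, 1L -> 2, 5L -> 3)

  // Walk one cluster: scan forward from `start` while the source id matches.
  def clusterIterator(srcId: Long, start: Int): Iterator[(Long, Long)] =
    Iterator.from(start)
      .takeWhile(i => i < srcIds.length && srcIds(i) == srcId)
      .map(i => (srcIds(i), dstIds(i)))

  // Index scan: clusters whose source fails the predicate are never touched.
  def indexIterator(srcIdPred: Long => Boolean): Iterator[(Long, Long)] =
    index.iterator
      .filter { case (srcId, _) => srcIdPred(srcId) }
      .flatMap { case (srcId, start) => clusterIterator(srcId, start) }

  println(indexIterator(_ == 5L).toList)  // List((5,3), (5,4), (5,6))
}
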
graph/src/main/scala/org/apache/spark/graph/impl/EdgePartitionBuilder.scala

Lines changed: 30 additions & 9 deletions

@@ -1,24 +1,45 @@
 package org.apache.spark.graph.impl
 
+import scala.util.Sorting
+
 import org.apache.spark.graph._
-import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
 
 
 //private[graph]
-class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassManifest] {
+class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassManifest](size: Int = 64) {
 
-  val srcIds = new PrimitiveVector[Vid]
-  val dstIds = new PrimitiveVector[Vid]
-  var dataBuilder = new PrimitiveVector[ED]
+  var edges = new PrimitiveVector[Edge[ED]](size)
 
   /** Add a new edge to the partition. */
   def add(src: Vid, dst: Vid, d: ED) {
-    srcIds += src
-    dstIds += dst
-    dataBuilder += d
+    edges += Edge(src, dst, d)
   }
 
   def toEdgePartition: EdgePartition[ED] = {
-    new EdgePartition(srcIds.trim().array, dstIds.trim().array, dataBuilder.trim().array)
+    val edgeArray = edges.trim().array
+    Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
+    val srcIds = new Array[Vid](edgeArray.size)
+    val dstIds = new Array[Vid](edgeArray.size)
+    val data = new Array[ED](edgeArray.size)
+    val index = new PrimitiveKeyOpenHashMap[Vid, Int]
+    // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters
+    // and adding them to the index
+    if (edgeArray.length > 0) {
+      index.update(srcIds(0), 0)
+      var currSrcId: Vid = srcIds(0)
+      var i = 0
+      while (i < edgeArray.size) {
+        srcIds(i) = edgeArray(i).srcId
+        dstIds(i) = edgeArray(i).dstId
+        data(i) = edgeArray(i).attr
+        if (edgeArray(i).srcId != currSrcId) {
+          currSrcId = edgeArray(i).srcId
+          index.update(currSrcId, i)
+        }
+        i += 1
+      }
+    }
+    new EdgePartition(srcIds, dstIds, data, index)
  }
 }
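
toEdgePartition is where the clustering happens: sort the buffered edges with Edge.lexicographicOrdering, copy them into the columnar arrays, and record the position where each source id's run begins. A standalone sketch of that indexing pass over plain tuples (a mutable Map stands in for PrimitiveKeyOpenHashMap; illustration only):

import scala.collection.mutable

object BuilderSketch extends App {
  val added = Array((5L, 4L, "e"), (0L, 2L, "b"), (5L, 3L, "d"), (0L, 1L, "a"))
  // Sort so edges sharing a source id form contiguous clusters.
  val sorted = added.sortBy { case (src, dst, _) => (src, dst) }

  // One pass: whenever the source id changes, a new cluster starts at i.
  val index = mutable.Map[Long, Int]()
  var currSrcId = Long.MinValue  // sentinel, assumed not to be a real vertex id
  for (i <- sorted.indices) {
    if (sorted(i)._1 != currSrcId) {
      currSrcId = sorted(i)._1
      index(currSrcId) = i
    }
  }
  println(sorted.mkString(", "))  // (0,1,a), (0,2,b), (5,3,d), (5,4,e)
  println(index)                  // contains 0 -> 0 and 5 -> 2
}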

graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala

Lines changed: 32 additions & 25 deletions

@@ -245,37 +245,44 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
     // Map and combine.
     val preAgg = edges.zipEdgePartitions(vs) { (edgePartition, vTableReplicatedIter) =>
-      val (_, vertexPartition) = vTableReplicatedIter.next()
+      val (_, vPart) = vTableReplicatedIter.next()
+
+      // Choose scan method
+      val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
+      val edgeIter = activeDirectionOpt match {
+        case Some(EdgeDirection.Both) =>
+          if (activeFraction < 0.8) {
+            edgePartition.indexIterator(srcVid => vPart.isActive(srcVid))
+              .filter(e => vPart.isActive(e.dstId))
+          } else {
+            edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId))
+          }
+        case Some(EdgeDirection.Out) =>
+          if (activeFraction < 0.8) {
+            edgePartition.indexIterator(srcVid => vPart.isActive(srcVid))
+          } else {
+            edgePartition.iterator.filter(e => vPart.isActive(e.srcId))
+          }
+        case Some(EdgeDirection.In) =>
+          edgePartition.iterator.filter(e => vPart.isActive(e.dstId))
+        case None =>
+          edgePartition.iterator
+      }
 
-      // Iterate over the partition
+      // Scan edges and run the map function
       val et = new EdgeTriplet[VD, ED]
-      val filteredEdges = edgePartition.iterator.flatMap { e =>
-        // Ensure the edge is adjacent to a vertex in activeSet if necessary
-        val adjacent = activeDirectionOpt match {
-          case Some(EdgeDirection.In) =>
-            vertexPartition.isActive(e.dstId)
-          case Some(EdgeDirection.Out) =>
-            vertexPartition.isActive(e.srcId)
-          case Some(EdgeDirection.Both) =>
-            vertexPartition.isActive(e.srcId) && vertexPartition.isActive(e.dstId)
-          case None =>
-            true
+      val mapOutputs = edgeIter.flatMap { e =>
+        et.set(e)
+        if (mapUsesSrcAttr) {
+          et.srcAttr = vPart(e.srcId)
         }
-        if (adjacent) {
-          et.set(e)
-          if (mapUsesSrcAttr) {
-            et.srcAttr = vertexPartition(e.srcId)
-          }
-          if (mapUsesDstAttr) {
-            et.dstAttr = vertexPartition(e.dstId)
-          }
-          mapFunc(et)
-        } else {
-          Iterator.empty
+        if (mapUsesDstAttr) {
+          et.dstAttr = vPart(e.dstId)
         }
+        mapFunc(et)
       }
       // Note: This doesn't allow users to send messages to arbitrary vertices.
-      vertexPartition.aggregateUsingIndex(filteredEdges, reduceFunc).iterator
+      vPart.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
     }
 
     // do the final reduction reusing the index map
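
The new scan selection compares the active-vertex count against indexSize, the number of distinct source vertices in the partition, so activeFraction estimates the share of clusters an index scan would have to visit; below 0.8 the index scan wins, otherwise the sequential full scan does. A standalone sketch of just the decision rule (names are illustrative; the 0.8 cutoff is the commit's):

object ScanChoiceSketch extends App {
  def chooseScan(numActives: Option[Int], indexSize: Int): String = {
    val activeFraction = numActives.getOrElse(0) / indexSize.toFloat
    if (activeFraction < 0.8) "clustered index scan" else "full edge scan"
  }

  // 50 of 1000 distinct sources active: the index skips ~95% of the clusters.
  println(chooseScan(Some(50), 1000))   // clustered index scan
  // 950 of 1000 active: one sequential pass beats per-cluster lookups.
  println(chooseScan(Some(950), 1000))  // full edge scan
}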

graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala

Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ class VTableReplicated[VD: ClassManifest](
       for (i <- 0 until block.vids.size) {
         val vid = block.vids(i)
         val attr = block.attrs(i)
-        val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
+        val ind = vidToIndex.getPos(vid)
         vertexArray(ind) = attr
       }
     }

graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala

Lines changed: 3 additions & 0 deletions

@@ -54,6 +54,9 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
     activeSet.get.contains(vid)
   }
 
+  /** The number of active vertices, if any exist. */
+  def numActives: Option[Int] = activeSet.map(_.size)
+
   /**
    * Pass each vertex attribute along with the vertex id through a map
    * function and retain the original RDD's partitioning and index.

graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala

Lines changed: 25 additions & 0 deletions

@@ -1,9 +1,12 @@
 package org.apache.spark.graph
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext
 import org.apache.spark.graph.LocalSparkContext._
+import org.apache.spark.graph.impl.EdgePartitionBuilder
 import org.apache.spark.rdd._
 
 class GraphSuite extends FunSuite with LocalSparkContext {
@@ -59,6 +62,13 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     // mapVertices changing type
     val mappedVAttrs2 = reverseStar.mapVertices((vid, attr) => attr.length)
     assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: Vid, 1)).toSet)
+    // groupEdges
+    val doubleStar = Graph.fromEdgeTuples(
+      sc.parallelize((1 to n).flatMap(x => List((0: Vid, x: Vid), (0: Vid, x: Vid))), 1), "v")
+    val star2 = doubleStar.groupEdges { (a, b) => a }
+    assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) ===
+      star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]))
+    assert(star2.vertices.collect.toSet === star.vertices.collect.toSet)
   }
 }
 
@@ -206,4 +216,19 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
     }
   }
+
+  test("EdgePartition.sort") {
+    val edgesFrom0 = List(Edge(0, 1, 0))
+    val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
+    val sortedEdges = edgesFrom0 ++ edgesFrom1
+    val builder = new EdgePartitionBuilder[Int]
+    for (e <- Random.shuffle(sortedEdges)) {
+      builder.add(e.srcId, e.dstId, e.attr)
+    }
+
+    val edgePartition = builder.toEdgePartition
+    assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
+    assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
+    assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
+  }
 }
