
Commit 1b5eacb

Merge pull request apache#102 from ankurdave/clustered-edge-index

Add clustered index on edges by source vertex

2 parents 0476c84 + 3ade8be

7 files changed: +167 -65 lines changed

graph/src/main/scala/org/apache/spark/graph/Edge.scala
Lines changed: 7 additions & 0 deletions

@@ -41,3 +41,10 @@ case class Edge[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED]
   def relativeDirection(vid: Vid): EdgeDirection =
     if (vid == srcId) EdgeDirection.Out else { assert(vid == dstId); EdgeDirection.In }
 }
+
+object Edge {
+  def lexicographicOrdering[ED] = new Ordering[Edge[ED]] {
+    override def compare(a: Edge[ED], b: Edge[ED]): Int =
+      Ordering[(Vid, Vid)].compare((a.srcId, a.dstId), (b.srcId, b.dstId))
+  }
+}
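
A minimal usage sketch of the new ordering (assuming Vid is an alias for Long, as in the graph package object): sorting an edge array with Edge.lexicographicOrdering clusters edges by source id, breaking ties by destination id, which is the layout the clustered index below depends on.

    import scala.util.Sorting
    import org.apache.spark.graph._

    val edges = Array(Edge(2L, 1L, "b"), Edge(1L, 3L, "a"), Edge(1L, 2L, "c"))
    Sorting.quickSort(edges)(Edge.lexicographicOrdering[String])
    // edges is now clustered by source id:
    // Array(Edge(1, 2, "c"), Edge(1, 3, "a"), Edge(2, 1, "b"))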

graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
Lines changed: 69 additions & 30 deletions

@@ -1,29 +1,36 @@
 package org.apache.spark.graph.impl
 
 import org.apache.spark.graph._
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 
 /**
- * A collection of edges stored in 3 large columnar arrays (src, dst, attribute).
+ * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are
+ * clustered by src.
  *
  * @param srcIds the source vertex id of each edge
 * @param dstIds the destination vertex id of each edge
 * @param data the attribute associated with each edge
+ * @param index a clustered index on source vertex id
 * @tparam ED the edge attribute type.
 */
 class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassManifest](
     val srcIds: Array[Vid],
     val dstIds: Array[Vid],
-    val data: Array[ED]) {
+    val data: Array[ED],
+    val index: PrimitiveKeyOpenHashMap[Vid, Int]) {
 
   /**
    * Reverse all the edges in this partition.
    *
-   * @note No new data structures are created.
-   *
    * @return a new edge partition with all edges reversed.
    */
-  def reverse: EdgePartition[ED] = new EdgePartition(dstIds, srcIds, data)
+  def reverse: EdgePartition[ED] = {
+    val builder = new EdgePartitionBuilder(size)
+    for (e <- iterator) {
+      builder.add(e.dstId, e.srcId, e.attr)
+    }
+    builder.toEdgePartition
+  }
 
   /**
    * Construct a new edge partition by applying the function f to all

@@ -46,25 +53,16 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
       newData(i) = f(edge)
       i += 1
     }
-    new EdgePartition(srcIds, dstIds, newData)
+    new EdgePartition(srcIds, dstIds, newData, index)
   }
 
   /**
    * Apply the function f to all edges in this partition.
    *
    * @param f an external state mutating user defined function.
    */
-  def foreach(f: Edge[ED] => Unit) {
-    val edge = new Edge[ED]
-    val size = data.size
-    var i = 0
-    while (i < size) {
-      edge.srcId = srcIds(i)
-      edge.dstId = dstIds(i)
-      edge.attr = data(i)
-      f(edge)
-      i += 1
-    }
+  def foreach(f: Edge[ED] => Unit) {
+    iterator.foreach(f)
   }
 
   /**

@@ -75,21 +73,29 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
    * @return a new edge partition without duplicate edges
    */
   def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED] = {
-    // Aggregate all matching edges in a hashmap
-    val agg = new OpenHashMap[(Vid,Vid), ED]
-    foreach { e => agg.setMerge((e.srcId, e.dstId), e.attr, merge) }
-    // Populate new srcId, dstId, and data, arrays
-    val newSrcIds = new Array[Vid](agg.size)
-    val newDstIds = new Array[Vid](agg.size)
-    val newData = new Array[ED](agg.size)
+    val builder = new EdgePartitionBuilder[ED]
+    var firstIter: Boolean = true
+    var currSrcId: Vid = nullValue[Vid]
+    var currDstId: Vid = nullValue[Vid]
+    var currAttr: ED = nullValue[ED]
     var i = 0
-    agg.foreach { kv =>
-      newSrcIds(i) = kv._1._1
-      newDstIds(i) = kv._1._2
-      newData(i) = kv._2
+    while (i < size) {
+      if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+        currAttr = merge(currAttr, data(i))
+      } else {
+        if (i > 0) {
+          builder.add(currSrcId, currDstId, currAttr)
+        }
+        currSrcId = srcIds(i)
+        currDstId = dstIds(i)
+        currAttr = data(i)
+      }
       i += 1
     }
-    new EdgePartition(newSrcIds, newDstIds, newData)
+    if (size > 0) {
+      builder.add(currSrcId, currDstId, currAttr)
+    }
+    builder.toEdgePartition
   }
 
   /**

@@ -99,6 +105,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
    */
   def size: Int = srcIds.size
 
+  /** The number of unique source vertices in the partition. */
+  def indexSize: Int = index.size
+
   /**
    * Get an iterator over the edges in this partition.
    *

@@ -118,4 +127,34 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
       edge
     }
   }
+
+  /**
+   * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
+   * iterator is generated using an index scan, so it is efficient at skipping edges that don't
+   * match srcIdPred.
+   */
+  def indexIterator(srcIdPred: Vid => Boolean): Iterator[Edge[ED]] =
+    index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
+
+  /**
+   * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
+   * cluster must start at position `index`.
+   */
+  private def clusterIterator(srcId: Vid, index: Int) = new Iterator[Edge[ED]] {
+    private[this] val edge = new Edge[ED]
+    private[this] var pos = index
+
+    override def hasNext: Boolean = {
+      pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+    }
+
+    override def next(): Edge[ED] = {
+      assert(srcIds(pos) == srcId)
+      edge.srcId = srcIds(pos)
+      edge.dstId = dstIds(pos)
+      edge.attr = data(pos)
+      pos += 1
+      edge
+    }
+  }
 }
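
The index scan behind indexIterator can be modeled without the Spark collection classes. A hypothetical sketch using a plain Scala Map in place of PrimitiveKeyOpenHashMap: the index maps each source id to the offset where its cluster begins in the sorted srcIds array, and a cluster is read by walking forward while the source id still matches.

    // Hypothetical model of the clustered index; srcIds must be sorted by source id.
    def clusterPositions(srcIds: Array[Long], index: Map[Long, Int], srcId: Long): Iterator[Int] =
      index.get(srcId) match {
        case Some(start) =>
          Iterator.from(start).takeWhile(pos => pos < srcIds.length && srcIds(pos) == srcId)
        case None => Iterator.empty
      }

    val srcIds = Array(1L, 1L, 2L, 5L, 5L, 5L)
    val index = Map(1L -> 0, 2L -> 2, 5L -> 3)
    assert(clusterPositions(srcIds, index, 5L).toList == List(3, 4, 5))

This is why indexIterator is cheap when few distinct sources match the predicate: one hash lookup per matching source instead of a scan over every edge.
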
graph/src/main/scala/org/apache/spark/graph/impl/EdgePartitionBuilder.scala
Lines changed: 30 additions & 9 deletions

@@ -1,24 +1,45 @@
 package org.apache.spark.graph.impl
 
+import scala.util.Sorting
+
 import org.apache.spark.graph._
-import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
 
 
 //private[graph]
-class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassManifest] {
+class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassManifest](size: Int = 64) {
 
-  val srcIds = new PrimitiveVector[Vid]
-  val dstIds = new PrimitiveVector[Vid]
-  var dataBuilder = new PrimitiveVector[ED]
+  var edges = new PrimitiveVector[Edge[ED]](size)
 
   /** Add a new edge to the partition. */
   def add(src: Vid, dst: Vid, d: ED) {
-    srcIds += src
-    dstIds += dst
-    dataBuilder += d
+    edges += Edge(src, dst, d)
   }
 
   def toEdgePartition: EdgePartition[ED] = {
-    new EdgePartition(srcIds.trim().array, dstIds.trim().array, dataBuilder.trim().array)
+    val edgeArray = edges.trim().array
+    Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
+    val srcIds = new Array[Vid](edgeArray.size)
+    val dstIds = new Array[Vid](edgeArray.size)
+    val data = new Array[ED](edgeArray.size)
+    val index = new PrimitiveKeyOpenHashMap[Vid, Int]
+    // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters
+    // and adding them to the index
+    if (edgeArray.length > 0) {
+      // Seed the index from the first sorted edge (edgeArray, not the still-unfilled srcIds)
+      index.update(edgeArray(0).srcId, 0)
+      var currSrcId: Vid = edgeArray(0).srcId
+      var i = 0
+      while (i < edgeArray.size) {
+        srcIds(i) = edgeArray(i).srcId
+        dstIds(i) = edgeArray(i).dstId
+        data(i) = edgeArray(i).attr
+        if (edgeArray(i).srcId != currSrcId) {
+          currSrcId = edgeArray(i).srcId
+          index.update(currSrcId, i)
+        }
+        i += 1
+      }
+    }
+    new EdgePartition(srcIds, dstIds, data, index)
   }
 }
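
A usage sketch under the same assumption that Vid aliases Long: the builder accepts edges in any order, and toEdgePartition sorts them and records the offset at which each source id's cluster starts.

    val builder = new EdgePartitionBuilder[Int]
    builder.add(2L, 0L, 1)
    builder.add(0L, 2L, 1)
    builder.add(0L, 1L, 1)
    val part = builder.toEdgePartition
    // part.srcIds == Array(0, 0, 2); the index maps 0 -> offset 0 and 2 -> offset 2,
    // so part.indexSize == 2 and part.indexIterator(_ == 0L) yields both edges from 0.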

graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
Lines changed: 32 additions & 25 deletions

@@ -245,37 +245,44 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
     // Map and combine.
     val preAgg = edges.zipEdgePartitions(vs) { (edgePartition, vTableReplicatedIter) =>
-      val (_, vertexPartition) = vTableReplicatedIter.next()
+      val (_, vPart) = vTableReplicatedIter.next()
+
+      // Choose scan method
+      val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
+      val edgeIter = activeDirectionOpt match {
+        case Some(EdgeDirection.Both) =>
+          if (activeFraction < 0.8) {
+            edgePartition.indexIterator(srcVid => vPart.isActive(srcVid))
+              .filter(e => vPart.isActive(e.dstId))
+          } else {
+            edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId))
+          }
+        case Some(EdgeDirection.Out) =>
+          if (activeFraction < 0.8) {
+            edgePartition.indexIterator(srcVid => vPart.isActive(srcVid))
+          } else {
+            edgePartition.iterator.filter(e => vPart.isActive(e.srcId))
+          }
+        case Some(EdgeDirection.In) =>
+          edgePartition.iterator.filter(e => vPart.isActive(e.dstId))
+        case None =>
+          edgePartition.iterator
+      }
 
-      // Iterate over the partition
+      // Scan edges and run the map function
       val et = new EdgeTriplet[VD, ED]
-      val filteredEdges = edgePartition.iterator.flatMap { e =>
-        // Ensure the edge is adjacent to a vertex in activeSet if necessary
-        val adjacent = activeDirectionOpt match {
-          case Some(EdgeDirection.In) =>
-            vertexPartition.isActive(e.dstId)
-          case Some(EdgeDirection.Out) =>
-            vertexPartition.isActive(e.srcId)
-          case Some(EdgeDirection.Both) =>
-            vertexPartition.isActive(e.srcId) && vertexPartition.isActive(e.dstId)
-          case None =>
-            true
+      val mapOutputs = edgeIter.flatMap { e =>
+        et.set(e)
+        if (mapUsesSrcAttr) {
+          et.srcAttr = vPart(e.srcId)
         }
-        if (adjacent) {
-          et.set(e)
-          if (mapUsesSrcAttr) {
-            et.srcAttr = vertexPartition(e.srcId)
-          }
-          if (mapUsesDstAttr) {
-            et.dstAttr = vertexPartition(e.dstId)
-          }
-          mapFunc(et)
-        } else {
-          Iterator.empty
+        if (mapUsesDstAttr) {
+          et.dstAttr = vPart(e.dstId)
         }
+        mapFunc(et)
       }
       // Note: This doesn't allow users to send messages to arbitrary vertices.
-      vertexPartition.aggregateUsingIndex(filteredEdges, reduceFunc).iterator
+      vPart.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
     }
 
     // do the final reduction reusing the index map
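
The 0.8 cutoff above is a heuristic: an index scan skips whole clusters of inactive sources but gives up the sequential access pattern of a full scan, so it only pays off when a minority of distinct sources are active. A distillation of the choice, with invented names for illustration:

    // indexSize is the number of distinct source vertices in the edge partition;
    // numActives is None when no active set was supplied for this iteration.
    def chooseScan(numActives: Option[Int], indexSize: Int): String = {
      val activeFraction = numActives.getOrElse(0) / indexSize.toFloat
      if (activeFraction < 0.8) "index scan" else "sequential scan with a filter"
    }

Keeping the heuristic local to the partition (active count over distinct sources) lets each task choose its scan method with no extra communication.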

graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicated.scala
Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ class VTableReplicated[VD: ClassManifest](
       for (i <- 0 until block.vids.size) {
         val vid = block.vids(i)
         val attr = block.attrs(i)
-        val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
+        val ind = vidToIndex.getPos(vid)
         vertexArray(ind) = attr
       }
     }

graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
Lines changed: 3 additions & 0 deletions

@@ -54,6 +54,9 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
     activeSet.get.contains(vid)
   }
 
+  /** The number of active vertices, if any exist. */
+  def numActives: Option[Int] = activeSet.map(_.size)
+
   /**
    * Pass each vertex attribute along with the vertex id through a map
    * function and retain the original RDD's partitioning and index.

graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
Lines changed: 25 additions & 0 deletions

@@ -1,9 +1,12 @@
 package org.apache.spark.graph
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext
 import org.apache.spark.graph.LocalSparkContext._
+import org.apache.spark.graph.impl.EdgePartitionBuilder
 import org.apache.spark.rdd._
 
 class GraphSuite extends FunSuite with LocalSparkContext {

@@ -59,6 +62,13 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       // mapVertices changing type
      val mappedVAttrs2 = reverseStar.mapVertices((vid, attr) => attr.length)
      assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: Vid, 1)).toSet)
+      // groupEdges
+      val doubleStar = Graph.fromEdgeTuples(
+        sc.parallelize((1 to n).flatMap(x => List((0: Vid, x: Vid), (0: Vid, x: Vid))), 1), "v")
+      val star2 = doubleStar.groupEdges { (a, b) => a }
+      assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) ===
+        star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]))
+      assert(star2.vertices.collect.toSet === star.vertices.collect.toSet)
     }
   }
 

@@ -206,4 +216,19 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
     }
   }
+
+  test("EdgePartition.sort") {
+    val edgesFrom0 = List(Edge(0, 1, 0))
+    val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
+    val sortedEdges = edgesFrom0 ++ edgesFrom1
+    val builder = new EdgePartitionBuilder[Int]
+    for (e <- Random.shuffle(sortedEdges)) {
+      builder.add(e.srcId, e.dstId, e.attr)
+    }
+
+    val edgePartition = builder.toEdgePartition
+    assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
+    assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
+    assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
+  }
 }
