Skip to content

Commit a645ef6

Browse files
committed
Merge pull request apache#48 from amatsukawa/add_project_to_graph
Add mask operation on graph and filter graph primitive
2 parents 3fd2e09 + d7ebff0 commit a645ef6

File tree

6 files changed

+168
-2
lines changed

6 files changed

+168
-2
lines changed

graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class EdgeRDD[@specialized ED: ClassManifest](
4444
override def cache(): EdgeRDD[ED] = persist()
4545

4646
def mapEdgePartitions[ED2: ClassManifest](f: EdgePartition[ED] => EdgePartition[ED2])
47-
: EdgeRDD[ED2]= {
47+
: EdgeRDD[ED2] = {
4848
new EdgeRDD[ED2](partitionsRDD.mapPartitions({ iter =>
4949
val (pid, ep) = iter.next()
5050
Iterator(Tuple2(pid, f(ep)))
@@ -60,6 +60,27 @@ class EdgeRDD[@specialized ED: ClassManifest](
6060
}
6161
}
6262

63+
def zipEdgePartitions[ED2: ClassManifest, ED3: ClassManifest]
64+
(other: EdgeRDD[ED2])
65+
(f: (EdgePartition[ED], EdgePartition[ED2]) => EdgePartition[ED3]): EdgeRDD[ED3] = {
66+
new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, preservesPartitioning = true) {
67+
(thisIter, otherIter) =>
68+
val (pid, thisEPart) = thisIter.next()
69+
val (_, otherEPart) = otherIter.next()
70+
Iterator(Tuple2(pid, f(thisEPart, otherEPart)))
71+
})
72+
}
73+
74+
def innerJoin[ED2: ClassManifest, ED3: ClassManifest]
75+
(other: EdgeRDD[ED2])
76+
(f: (Vid, Vid, ED, ED2) => ED3): EdgeRDD[ED3] = {
77+
val ed2Manifest = classManifest[ED2]
78+
val ed3Manifest = classManifest[ED3]
79+
zipEdgePartitions(other) { (thisEPart, otherEPart) =>
80+
thisEPart.innerJoin(otherEPart)(f)(ed2Manifest, ed3Manifest)
81+
}
82+
}
83+
6384
def collectVids(): RDD[Vid] = {
6485
partitionsRDD.flatMap { case (_, p) => Array.concat(p.srcIds, p.dstIds) }
6586
}

graph/src/main/scala/org/apache/spark/graph/Graph.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
4848
* along with their vertex data.
4949
*
5050
*/
51-
val edges: RDD[Edge[ED]]
51+
val edges: EdgeRDD[ED]
5252

5353
/**
5454
* Get the edges with the vertex data associated with the adjacent
@@ -197,6 +197,14 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
197197
def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
198198
vpred: (Vid, VD) => Boolean = ((v,d) => true) ): Graph[VD, ED]
199199

200+
/**
201+
* Subgraph of this graph with only vertices and edges from the other graph.
202+
* @param other the graph to project this graph onto
203+
* @return a graph with vertices and edges that exists in both the current graph and other,
204+
* with vertex and edge data from the current graph.
205+
*/
206+
def mask[VD2: ClassManifest, ED2: ClassManifest](other: Graph[VD2, ED2]): Graph[VD, ED]
207+
200208
/**
201209
* This function merges multiple edges between two vertices into a single Edge. For correct
202210
* results, the graph must have been partitioned using partitionBy.

graph/src/main/scala/org/apache/spark/graph/GraphOps.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,35 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) {
237237
graph.outerJoinVertices(table)(uf)
238238
}
239239

240+
/**
241+
* Filter the graph by computing some values to filter on, and applying the predicates.
242+
*
243+
* @param preprocess a function to compute new vertex and edge data before filtering
244+
* @param epred edge pred to filter on after preprocess, see more details under Graph#subgraph
245+
* @param vpred vertex pred to filter on after prerocess, see more details under Graph#subgraph
246+
* @tparam VD2 vertex type the vpred operates on
247+
* @tparam ED2 edge type the epred operates on
248+
* @return a subgraph of the orginal graph, with its data unchanged
249+
*
250+
* @example This function can be used to filter the graph based on some property, without
251+
* changing the vertex and edge values in your program. For example, we could remove the vertices
252+
* in a graph with 0 outdegree
253+
*
254+
* {{{
255+
* graph.filter(
256+
* graph => {
257+
* val degrees: VertexSetRDD[Int] = graph.outDegrees
258+
* graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
259+
* },
260+
* vpred = (vid: Vid, deg:Int) => deg > 0
261+
* )
262+
* }}}
263+
*
264+
*/
265+
def filter[VD2: ClassManifest, ED2: ClassManifest](
266+
preprocess: Graph[VD, ED] => Graph[VD2, ED2],
267+
epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true,
268+
vpred: (Vid, VD2) => Boolean = (v:Vid, d:VD2) => true): Graph[VD, ED] = {
269+
graph.mask(preprocess(graph).subgraph(epred, vpred))
270+
}
240271
} // end of GraphOps

graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,40 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
9898
builder.toEdgePartition
9999
}
100100

101+
/**
102+
* Apply `f` to all edges present in both `this` and `other` and return a new EdgePartition
103+
* containing the resulting edges.
104+
*
105+
* If there are multiple edges with the same src and dst in `this`, `f` will be invoked once for
106+
* each edge, but each time it may be invoked on any corresponding edge in `other`.
107+
*
108+
* If there are multiple edges with the same src and dst in `other`, `f` will only be invoked
109+
* once.
110+
*/
111+
def innerJoin[ED2: ClassManifest, ED3: ClassManifest]
112+
(other: EdgePartition[ED2])
113+
(f: (Vid, Vid, ED, ED2) => ED3): EdgePartition[ED3] = {
114+
val builder = new EdgePartitionBuilder[ED3]
115+
var i = 0
116+
var j = 0
117+
// For i = index of each edge in `this`...
118+
while (i < size && j < other.size) {
119+
val srcId = this.srcIds(i)
120+
val dstId = this.dstIds(i)
121+
// ... forward j to the index of the corresponding edge in `other`, and...
122+
while (j < other.size && other.srcIds(j) < srcId) { j += 1 }
123+
if (j < other.size && other.srcIds(j) == srcId) {
124+
while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
125+
if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
126+
// ... run `f` on the matching edge
127+
builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
128+
}
129+
}
130+
i += 1
131+
}
132+
builder.toEdgePartition
133+
}
134+
101135
/**
102136
* The number of edges in this partition
103137
*

graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,14 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
215215
new GraphImpl(newVTable, newETable)
216216
} // end of subgraph
217217

218+
override def mask[VD2: ClassManifest, ED2: ClassManifest] (
219+
other: Graph[VD2, ED2]): Graph[VD, ED] = {
220+
val newVerts = vertices.innerJoin(other.vertices) { (vid, v, w) => v }
221+
val newEdges = edges.innerJoin(other.edges) { (src, dst, v, w) => v }
222+
new GraphImpl(newVerts, newEdges)
223+
224+
}
225+
218226
override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = {
219227
ClosureCleaner.clean(merge)
220228
val newETable = edges.mapEdgePartitions(_.groupEdges(merge))

graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import scala.util.Random
55
import org.scalatest.FunSuite
66

77
import org.apache.spark.SparkContext
8+
import org.apache.spark.graph.Graph._
89
import org.apache.spark.graph.LocalSparkContext._
10+
import org.apache.spark.graph.impl.EdgePartition
911
import org.apache.spark.graph.impl.EdgePartitionBuilder
1012
import org.apache.spark.rdd._
1113

@@ -183,6 +185,53 @@ class GraphSuite extends FunSuite with LocalSparkContext {
183185
}
184186
}
185187

188+
test("mask") {
189+
withSpark(new SparkContext("local", "test")) { sc =>
190+
val n = 5
191+
val vertices = sc.parallelize((0 to n).map(x => (x:Vid, x)))
192+
val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
193+
val graph: Graph[Int, Int] = Graph(vertices, edges)
194+
195+
val subgraph = graph.subgraph(
196+
e => e.dstId != 4L,
197+
(vid, vdata) => vid != 3L
198+
).mapVertices((vid, vdata) => -1).mapEdges(e => -1)
199+
200+
val projectedGraph = graph.mask(subgraph)
201+
202+
val v = projectedGraph.vertices.collect().toSet
203+
assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5)))
204+
205+
// the map is necessary because of object-reuse in the edge iterator
206+
val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
207+
assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5)))
208+
209+
}
210+
}
211+
212+
test ("filter") {
213+
withSpark(new SparkContext("local", "test")) { sc =>
214+
val n = 5
215+
val vertices = sc.parallelize((0 to n).map(x => (x:Vid, x)))
216+
val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
217+
val graph: Graph[Int, Int] = Graph(vertices, edges)
218+
val filteredGraph = graph.filter(
219+
graph => {
220+
val degrees: VertexRDD[Int] = graph.outDegrees
221+
graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
222+
},
223+
vpred = (vid: Vid, deg:Int) => deg > 0
224+
)
225+
226+
val v = filteredGraph.vertices.collect().toSet
227+
assert(v === Set((0,0)))
228+
229+
// the map is necessary because of object-reuse in the edge iterator
230+
val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
231+
assert(e.isEmpty)
232+
}
233+
}
234+
186235
test("VertexSetRDD") {
187236
withSpark(new SparkContext("local", "test")) { sc =>
188237
val n = 100
@@ -231,4 +280,19 @@ class GraphSuite extends FunSuite with LocalSparkContext {
231280
assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
232281
assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
233282
}
283+
284+
test("EdgePartition.innerJoin") {
285+
def makeEdgePartition[A: ClassManifest](xs: Iterable[(Int, Int, A)]): EdgePartition[A] = {
286+
val builder = new EdgePartitionBuilder[A]
287+
for ((src, dst, attr) <- xs) { builder.add(src: Vid, dst: Vid, attr) }
288+
builder.toEdgePartition
289+
}
290+
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
291+
val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
292+
val a = makeEdgePartition(aList)
293+
val b = makeEdgePartition(bList)
294+
295+
assert(a.innerJoin(b) { (src, dst, a, b) => a }.iterator.map(_.copy()).toList ===
296+
List(Edge(0, 1, 0), Edge(1, 0, 0), Edge(5, 5, 0)))
297+
}
234298
}

0 commit comments

Comments
 (0)