
Commit 3fd2e09

Merge pull request apache#104 from jianpingjwang/master
SVD++ demo
2 parents 1b5eacb + 06581b6

2 files changed: +173 −0 lines changed
graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
package org.apache.spark.graph.algorithms

import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.graph._
import scala.util.Random
import org.apache.commons.math.linear._

class VT ( // vertex type
  var v1: RealVector, // v1: p for user node, q for item node
  var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
  var bias: Double,
  var norm: Double // only for user node
) extends Serializable

class Msg ( // message
  var v1: RealVector,
  var v2: RealVector,
  var bias: Double
) extends Serializable

object Svdpp {
  // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
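  //
  // For reference (comment added here, paraphrasing the paper), the model's prediction rule is
  //   r_hat(u,i) = mu + b(u) + b(i) + q(i) . ( p(u) + |N(u)|^(-0.5) * sum_{j in N(u)} y(j) )
  // where mu is the global rating mean, b(*) are biases, and p, q, y are rank-dimensional factors.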
  def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
    // default parameters
    val rank = 10
    val maxIters = 20
    val minVal = 0.0
    val maxVal = 5.0
    val gamma1 = 0.007
    val gamma2 = 0.007
    val gamma6 = 0.005
    val gamma7 = 0.015

    // initialize a vertex with random rank-dimensional factors
    def defaultF(rank: Int) = {
      val v1 = new ArrayRealVector(rank)
      val v2 = new ArrayRealVector(rank)
      for (i <- 0 until rank) {
        v1.setEntry(i, Random.nextDouble)
        v2.setEntry(i, Random.nextDouble)
      }
      new VT(v1, v2, 0.0, 0.0)
    }

    // calculate initial norm and bias
    def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
      assert(et.srcAttr != null && et.dstAttr != null)
      Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
    }
    def reduceF0(g1: (Long, Double), g2: (Long, Double)) = {
      (g1._1 + g2._1, g1._2 + g2._2)
    }
    def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = {
      if (msg.isDefined) {
        vd.bias = msg.get._2 / msg.get._1 // mean rating of this vertex
        vd.norm = 1.0 / scala.math.sqrt(msg.get._1) // |N(u)|^(-0.5)
      }
      vd
    }

    // calculate global rating mean
    val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
    val u = rs / rc // global rating mean

    // make graph
    var g = Graph.fromEdges(edges, defaultF(rank)).cache()

    // calculate initial norm and bias
    val t0 = g.mapReduceTriplets(mapF0, reduceF0)
    g = g.outerJoinVertices(t0) { updateF0 } // reassign: Graph is immutable, keep the joined result

    // phase 1
    def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
      assert(et.srcAttr != null && et.dstAttr != null)
      Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes
    }
    def reduceF1(g1: RealVector, g2: RealVector) = {
      g1.add(g2)
    }
    def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = {
      if (msg.isDefined) {
        vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y)
      }
      vd
    }

    // phase 2
    def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
      assert(et.srcAttr != null && et.dstAttr != null)
      val usr = et.srcAttr
      val itm = et.dstAttr
      val p = usr.v1
      val q = itm.v1
      var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
      pred = math.max(pred, minVal)
      pred = math.min(pred, maxVal)
      val err = et.attr - pred
      val y = q.mapMultiply(err * usr.norm).subtract(usr.v2.mapMultiply(gamma7))
      val newP = q.mapMultiply(err).subtract(p.mapMultiply(gamma7)) // gradient step for the user's p
      val newQ = usr.v2.mapMultiply(err).subtract(q.mapMultiply(gamma7)) // gradient step for the item's q
      Iterator((et.srcId, new Msg(newP, y, err - gamma6 * usr.bias)),
        (et.dstId, new Msg(newQ, y, err - gamma6 * itm.bias)))
    }
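    // updateF2 below applies these messages as the paper's SGD steps (comment added
    // for reference; gamma1/gamma2 are learning rates, gamma6/gamma7 regularizers):
    //   b   <- b   + gamma1 * (err - gamma6 * b)
    //   p_u <- p_u + gamma2 * (err * q_i - gamma7 * p_u)
    //   q_i <- q_i + gamma2 * (err * (p_u + |N(u)|^(-0.5)*sum(y)) - gamma7 * q_i)
    //   y_j <- y_j + gamma2 * (err * |N(u)|^(-0.5) * q_i - gamma7 * y_j)
    // note: the y message above regularizes with the user's composite vector usr.v2
    // rather than each y_j, a simplification relative to the paper.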
    def reduceF2(g1: Msg, g2: Msg): Msg = {
      g1.v1 = g1.v1.add(g2.v1)
      g1.v2 = g1.v2.add(g2.v2)
      g1.bias += g2.bias
      g1
    }
    def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
      if (msg.isDefined) {
        vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
        if (vid % 2 == 1) { // item node: update y
          vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
        }
        vd.bias += msg.get.bias * gamma1
      }
      vd
    }

    for (i <- 0 until maxIters) {
      // phase 1
      val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
      g = g.outerJoinVertices(t1) { updateF1 }
      // phase 2
      val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
      g = g.outerJoinVertices(t2) { updateF2 }
    }

    // calculate error on training set
    def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
      assert(et.srcAttr != null && et.dstAttr != null)
      val usr = et.srcAttr
      val itm = et.dstAttr
      val q = itm.v1
      var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
      pred = math.max(pred, minVal)
      pred = math.min(pred, maxVal)
      val err = (et.attr - pred) * (et.attr - pred)
      Iterator((et.dstId, err))
    }
    def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
      if (msg.isDefined && vid % 2 == 1) { // item node: sum up the errors
        vd.norm = msg.get
      }
      vd
    }
    val t3: VertexRDD[Double] = g.mapReduceTriplets(mapF3, _ + _)
    g = g.outerJoinVertices(t3) { updateF3 }
    g
  }
}
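For reference, a minimal driver exercising this demo might look like the sketch below. It assumes a local SparkContext and a comma-separated "userId,itemId,rating" file, and mirrors the test further down; the object name SvdppDemo is illustrative only.

  import org.apache.spark.SparkContext
  import org.apache.spark.graph.Edge
  import org.apache.spark.graph.algorithms.Svdpp

  object SvdppDemo {
    def main(args: Array[String]) {
      val sc = new SparkContext("local", "SvdppDemo")
      // users get even vertex ids, items get odd ones, so the two never collide
      val edges = sc.textFile("mllib/data/als/test.data").map { line =>
        val fields = line.split(",")
        Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
      }
      val graph = Svdpp.run(edges)
      // after training, each item vertex's norm field holds its summed squared error
      val sse = graph.vertices.map { case (vid, vd) =>
        if (vid % 2 == 1) vd.norm else 0.0
      }.reduce(_ + _)
      println("training MSE: " + sse / graph.triplets.count)
      sc.stop()
    }
  }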

graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala

Lines changed: 15 additions & 0 deletions
@@ -257,4 +257,19 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext
      verts.collect.foreach { case (vid, count) => assert(count === 1) }
    }
  }

  test("Test SVD++ with mean square error on training set") {
    withSpark(new SparkContext("local", "test")) { sc =>
      val SvdppErr = 0.01
      val edges = sc.textFile("mllib/data/als/test.data").map { line =>
        val fields = line.split(",")
        Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
      }
      val graph = Svdpp.run(edges)
      val err = graph.vertices.collect.map { case (vid, vd) =>
        if (vid % 2 == 1) { vd.norm } else { 0.0 }
      }.reduce(_ + _) / graph.triplets.collect.size
      assert(err < SvdppErr)
    }
  }
} // end of AnalyticsSuite
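The doubling in the test's id mapping is deliberate: users land on even vertex ids and items on odd ones, which is exactly the convention the vid % 2 == 1 checks in Svdpp rely on, and why vd.norm can hold each item's summed squared error afterwards. A hypothetical pair of helpers (not part of this patch) makes the convention explicit:

  // hypothetical helpers, not in the original patch
  def userVid(userId: Long): Vid = userId * 2      // user -> even vertex id
  def itemVid(itemId: Long): Vid = itemId * 2 + 1  // item -> odd vertex id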
