|
| 1 | +package org.apache.spark.graph.algorithms |
| 2 | + |
| 3 | +import org.apache.spark._ |
| 4 | +import org.apache.spark.rdd._ |
| 5 | +import org.apache.spark.graph._ |
| 6 | +import scala.util.Random |
| 7 | +import org.apache.commons.math.linear._ |
| 8 | + |
| 9 | +class VT ( // vertex type |
| 10 | + var v1: RealVector, // v1: p for user node, q for item node |
| 11 | + var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node |
| 12 | + var bias: Double, |
| 13 | + var norm: Double // only for user node |
| 14 | +) extends Serializable |
| 15 | + |
| 16 | +class Msg ( // message |
| 17 | + var v1: RealVector, |
| 18 | + var v2: RealVector, |
| 19 | + var bias: Double |
| 20 | +) extends Serializable |
| 21 | + |
| 22 | +object Svdpp { |
| 23 | + // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf |
| 24 | + |
| 25 | + def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = { |
| 26 | + // defalut parameters |
| 27 | + val rank = 10 |
| 28 | + val maxIters = 20 |
| 29 | + val minVal = 0.0 |
| 30 | + val maxVal = 5.0 |
| 31 | + val gamma1 = 0.007 |
| 32 | + val gamma2 = 0.007 |
| 33 | + val gamma6 = 0.005 |
| 34 | + val gamma7 = 0.015 |
| 35 | + |
| 36 | + def defaultF(rank: Int) = { |
| 37 | + val v1 = new ArrayRealVector(rank) |
| 38 | + val v2 = new ArrayRealVector(rank) |
| 39 | + for (i <- 0 until rank) { |
| 40 | + v1.setEntry(i, Random.nextDouble) |
| 41 | + v2.setEntry(i, Random.nextDouble) |
| 42 | + } |
| 43 | + var vd = new VT(v1, v2, 0.0, 0.0) |
| 44 | + vd |
| 45 | + } |
| 46 | + |
| 47 | + // calculate initial norm and bias |
| 48 | + def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = { |
| 49 | + assert(et.srcAttr != null && et.dstAttr != null) |
| 50 | + Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))) |
| 51 | + } |
| 52 | + def reduceF0(g1: (Long, Double), g2: (Long, Double)) = { |
| 53 | + (g1._1 + g2._1, g1._2 + g2._2) |
| 54 | + } |
| 55 | + def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = { |
| 56 | + if (msg.isDefined) { |
| 57 | + vd.bias = msg.get._2 / msg.get._1 |
| 58 | + vd.norm = 1.0 / scala.math.sqrt(msg.get._1) |
| 59 | + } |
| 60 | + vd |
| 61 | + } |
| 62 | + |
| 63 | + // calculate global rating mean |
| 64 | + val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2)) |
| 65 | + val u = rs / rc // global rating mean |
| 66 | + |
| 67 | + // make graph |
| 68 | + var g = Graph.fromEdges(edges, defaultF(rank)).cache() |
| 69 | + |
| 70 | + // calculate initial norm and bias |
| 71 | + val t0 = g.mapReduceTriplets(mapF0, reduceF0) |
| 72 | + g.outerJoinVertices(t0) {updateF0} |
| 73 | + |
| 74 | + // phase 1 |
| 75 | + def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = { |
| 76 | + assert(et.srcAttr != null && et.dstAttr != null) |
| 77 | + Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes |
| 78 | + } |
| 79 | + def reduceF1(g1: RealVector, g2: RealVector) = { |
| 80 | + g1.add(g2) |
| 81 | + } |
| 82 | + def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = { |
| 83 | + if (msg.isDefined) { |
| 84 | + vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y) |
| 85 | + } |
| 86 | + vd |
| 87 | + } |
| 88 | + |
| 89 | + // phase 2 |
| 90 | + def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = { |
| 91 | + assert(et.srcAttr != null && et.dstAttr != null) |
| 92 | + val usr = et.srcAttr |
| 93 | + val itm = et.dstAttr |
| 94 | + var p = usr.v1 |
| 95 | + var q = itm.v1 |
| 96 | + val itmBias = 0.0 |
| 97 | + val usrBias = 0.0 |
| 98 | + var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2) |
| 99 | + pred = math.max(pred, minVal) |
| 100 | + pred = math.min(pred, maxVal) |
| 101 | + val err = et.attr - pred |
| 102 | + val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7)) |
| 103 | + val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q |
| 104 | + val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7)) |
| 105 | + Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias))) |
| 106 | + } |
| 107 | + def reduceF2(g1: Msg, g2: Msg):Msg = { |
| 108 | + g1.v1 = g1.v1.add(g2.v1) |
| 109 | + g1.v2 = g1.v2.add(g2.v2) |
| 110 | + g1.bias += g2.bias |
| 111 | + g1 |
| 112 | + } |
| 113 | + def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = { |
| 114 | + if (msg.isDefined) { |
| 115 | + vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2)) |
| 116 | + if (vid % 2 == 1) { // item node update y |
| 117 | + vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2)) |
| 118 | + } |
| 119 | + vd.bias += msg.get.bias*gamma1 |
| 120 | + } |
| 121 | + vd |
| 122 | + } |
| 123 | + |
| 124 | + for (i <- 0 until maxIters) { |
| 125 | + // phase 1 |
| 126 | + val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1) |
| 127 | + g.outerJoinVertices(t1) {updateF1} |
| 128 | + // phase 2 |
| 129 | + val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2) |
| 130 | + g.outerJoinVertices(t2) {updateF2} |
| 131 | + } |
| 132 | + |
| 133 | + // calculate error on training set |
| 134 | + def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = { |
| 135 | + assert(et.srcAttr != null && et.dstAttr != null) |
| 136 | + val usr = et.srcAttr |
| 137 | + val itm = et.dstAttr |
| 138 | + var p = usr.v1 |
| 139 | + var q = itm.v1 |
| 140 | + val itmBias = 0.0 |
| 141 | + val usrBias = 0.0 |
| 142 | + var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2) |
| 143 | + pred = math.max(pred, minVal) |
| 144 | + pred = math.min(pred, maxVal) |
| 145 | + val err = (et.attr - pred)*(et.attr - pred) |
| 146 | + Iterator((et.dstId, err)) |
| 147 | + } |
| 148 | + def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = { |
| 149 | + if (msg.isDefined && vid % 2 == 1) { // item sum up the errors |
| 150 | + vd.norm = msg.get |
| 151 | + } |
| 152 | + vd |
| 153 | + } |
| 154 | + val t3: VertexRDD[Double] = g.mapReduceTriplets(mapF3, _ + _) |
| 155 | + g.outerJoinVertices(t3) {updateF3} |
| 156 | + g |
| 157 | + } |
| 158 | +} |
0 commit comments