Skip to content

Commit 6340a18

Browse files
srowenmengxr
authored andcommitted
MLLIB-22. Support negative implicit input in ALS
I'm back with another less trivial suggestion for ALS: In ALS for implicit feedback, input values are treated as weights on squared-errors in a loss function (or rather, the weight is a simple function of the input r, like c = 1 + alpha*r). The paper on which it's based assumes that the input is positive. Indeed, if the input is negative, it will create a negative weight on squared-errors, which causes things to go haywire. The optimization will try to make the error in a cell as large possible, and the result is silently bogus. There is a good use case for negative input values though. Implicit feedback is usually collected from signals of positive interaction like a view or like or buy, but equally, can come from "not interested" signals. The natural representation is negative values. The algorithm can be extended quite simply to provide a sound interpretation of these values: negative values should encourage the factorization to come up with 0 for cells with large negative input values, just as much as positive values encourage it to come up with 1. The implications for the algorithm are simple: * the confidence function value must not be negative, and so can become 1 + alpha*|r| * the matrix P should have a value 1 where the input R is _positive_, not merely where it is non-zero. Actually, that's what the paper already says, it's just that we can't assume P = 1 when a cell in R is specified anymore, since it may be negative This in turn entails just a few lines of code change in `ALS.scala`: * `rs(i)` becomes `abs(rs(i))` * When constructing `userXy(us(i))`, it's implicitly only adding where P is 1. That had been true for any us(i) that is iterated over, before, since these are exactly the ones for which P is 1. But now P is zero where rs(i) <= 0, and should not be added I think it's a safe change because: * It doesn't change any existing behavior (unless you're using negative values, in which case results are already borked) * It's the simplest direct extension of the paper's algorithm * (I've used it to good effect in production FWIW) Tests included. I tweaked minor things en route: * `ALS.scala` javadoc writes "R = Xt*Y" when the paper and rest of code defines it as "R = X*Yt" * RMSE in the ALS tests uses a confidence-weighted mean, but the denominator is not actually sum of weights Excuse my Scala style; I'm sure it needs tweaks. Author: Sean Owen <sowen@cloudera.com> Closes #500 from srowen/ALSNegativeImplicitInput and squashes the following commits: cf902a9 [Sean Owen] Support negative implicit input in ALS 953be1c [Sean Owen] Make weighted RMSE in ALS test actually weighted; adjust comment about R = X*Yt
1 parent f27441a commit 6340a18

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
6464
* Alternating Least Squares matrix factorization.
6565
*
6666
* ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
67-
* `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices.
67+
* `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
6868
* The general approach is iterative. During each iteration, one of the factor matrices is held
6969
* constant, while the other is solved for using least squares. The newly-solved factor matrix is
7070
* then held constant while solving for the other factor matrix.
@@ -381,8 +381,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
381381
userXtX(us(i)).addi(tempXtX)
382382
SimpleBlas.axpy(rs(i), x, userXy(us(i)))
383383
case true =>
384-
userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i)))
385-
SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i)))
384+
// Extension to the original paper to handle rs(i) < 0. confidence is a function
385+
// of |rs(i)| instead so that it is never negative:
386+
val confidence = 1 + alpha * abs(rs(i))
387+
userXtX(us(i)).addi(tempXtX.mul(confidence - 1))
388+
// For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
389+
// means we try to reconstruct 0. We add terms only where P = 1, so, term below
390+
// is now only added for rs(i) > 0:
391+
if (rs(i) > 0) {
392+
SimpleBlas.axpy(confidence, x, userXy(us(i)))
393+
}
386394
}
387395
}
388396
}

mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import java.io.Serializable;
2121
import java.util.List;
22-
import java.lang.Math;
2322

2423
import org.junit.After;
2524
import org.junit.Assert;
@@ -46,7 +45,7 @@ public void tearDown() {
4645
System.clearProperty("spark.driver.port");
4746
}
4847

49-
void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
48+
static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
5049
DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
5150
DoubleMatrix predictedU = new DoubleMatrix(users, features);
5251
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
@@ -84,15 +83,15 @@ void validatePrediction(MatrixFactorizationModel model, int users, int products,
8483
for (int p = 0; p < products; ++p) {
8584
double prediction = predictedRatings.get(u, p);
8685
double truePref = truePrefs.get(u, p);
87-
double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
86+
double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
8887
double err = confidence * (truePref - prediction) * (truePref - prediction);
8988
sqErr += err;
90-
denom += 1.0;
89+
denom += confidence;
9190
}
9291
}
9392
double rmse = Math.sqrt(sqErr / denom);
9493
Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
95-
rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
94+
rmse, matchThreshold), rmse < matchThreshold);
9695
}
9796
}
9897

@@ -103,7 +102,7 @@ public void runALSUsingStaticMethods() {
103102
int users = 50;
104103
int products = 100;
105104
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
106-
users, products, features, 0.7, false);
105+
users, products, features, 0.7, false, false);
107106

108107
JavaRDD<Rating> data = sc.parallelize(testData._1());
109108
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
@@ -117,7 +116,7 @@ public void runALSUsingConstructor() {
117116
int users = 100;
118117
int products = 200;
119118
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
120-
users, products, features, 0.7, false);
119+
users, products, features, 0.7, false, false);
121120

122121
JavaRDD<Rating> data = sc.parallelize(testData._1());
123122

@@ -134,7 +133,7 @@ public void runImplicitALSUsingStaticMethods() {
134133
int users = 80;
135134
int products = 160;
136135
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
137-
users, products, features, 0.7, true);
136+
users, products, features, 0.7, true, false);
138137

139138
JavaRDD<Rating> data = sc.parallelize(testData._1());
140139
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
@@ -148,7 +147,7 @@ public void runImplicitALSUsingConstructor() {
148147
int users = 100;
149148
int products = 200;
150149
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
151-
users, products, features, 0.7, true);
150+
users, products, features, 0.7, true, false);
152151

153152
JavaRDD<Rating> data = sc.parallelize(testData._1());
154153

@@ -158,4 +157,19 @@ public void runImplicitALSUsingConstructor() {
158157
.run(data.rdd());
159158
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
160159
}
160+
161+
@Test
162+
public void runImplicitALSWithNegativeWeight() {
163+
int features = 2;
164+
int iterations = 15;
165+
int users = 80;
166+
int products = 160;
167+
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
168+
users, products, features, 0.7, true, true);
169+
170+
JavaRDD<Rating> data = sc.parallelize(testData._1());
171+
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
172+
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
173+
}
174+
161175
}

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
package org.apache.spark.mllib.recommendation
1919

2020
import scala.collection.JavaConversions._
21+
import scala.math.abs
2122
import scala.util.Random
2223

23-
import org.scalatest.BeforeAndAfterAll
2424
import org.scalatest.FunSuite
2525

2626
import org.jblas._
@@ -34,7 +34,8 @@ object ALSSuite {
3434
products: Int,
3535
features: Int,
3636
samplingRate: Double,
37-
implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
37+
implicitPrefs: Boolean,
38+
negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
3839
val (sampledRatings, trueRatings, truePrefs) =
3940
generateRatings(users, products, features, samplingRate, implicitPrefs)
4041
(seqAsJavaList(sampledRatings), trueRatings, truePrefs)
@@ -45,7 +46,8 @@ object ALSSuite {
4546
products: Int,
4647
features: Int,
4748
samplingRate: Double,
48-
implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
49+
implicitPrefs: Boolean = false,
50+
negativeWeights: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
4951
val rand = new Random(42)
5052

5153
// Create a random matrix with uniform values from -1 to 1
@@ -56,7 +58,9 @@ object ALSSuite {
5658
val productMatrix = randomMatrix(features, products)
5759
val (trueRatings, truePrefs) = implicitPrefs match {
5860
case true =>
59-
val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
61+
// Generate raw values from [0,9], or if negativeWeights, from [-2,7]
62+
val raw = new DoubleMatrix(users, products,
63+
Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*)
6064
val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
6165
(raw, prefs)
6266
case false => (userMatrix.mmul(productMatrix), null)
@@ -107,6 +111,10 @@ class ALSSuite extends FunSuite with LocalSparkContext {
107111
testALS(100, 200, 2, 15, 0.7, 0.4, true, true)
108112
}
109113

114+
test("rank-2 matrices implicit negative") {
115+
testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true)
116+
}
117+
110118
/**
111119
* Test if we can correctly factorize R = U * P where U and P are of known rank.
112120
*
@@ -118,13 +126,14 @@ class ALSSuite extends FunSuite with LocalSparkContext {
118126
* @param matchThreshold max difference allowed to consider a predicted rating correct
119127
* @param implicitPrefs flag to test implicit feedback
120128
* @param bulkPredict flag to test bulk prediciton
129+
* @param negativeWeights whether the generated data can contain negative values
121130
*/
122131
def testALS(users: Int, products: Int, features: Int, iterations: Int,
123132
samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
124-
bulkPredict: Boolean = false)
133+
bulkPredict: Boolean = false, negativeWeights: Boolean = false)
125134
{
126135
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
127-
features, samplingRate, implicitPrefs)
136+
features, samplingRate, implicitPrefs, negativeWeights)
128137
val model = implicitPrefs match {
129138
case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
130139
case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
@@ -166,13 +175,13 @@ class ALSSuite extends FunSuite with LocalSparkContext {
166175
for (u <- 0 until users; p <- 0 until products) {
167176
val prediction = predictedRatings.get(u, p)
168177
val truePref = truePrefs.get(u, p)
169-
val confidence = 1 + 1.0 * trueRatings.get(u, p)
178+
val confidence = 1 + 1.0 * abs(trueRatings.get(u, p))
170179
val err = confidence * (truePref - prediction) * (truePref - prediction)
171180
sqErr += err
172-
denom += 1
181+
denom += confidence
173182
}
174183
val rmse = math.sqrt(sqErr / denom)
175-
if (math.abs(rmse) > matchThreshold) {
184+
if (rmse > matchThreshold) {
176185
fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
177186
rmse, truePrefs, predictedRatings, predictedU, predictedP))
178187
}

0 commit comments

Comments
 (0)