Commit 44733e1

use in-place gradient computation
1 parent e981396 commit 44733e1
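
The commit message is terse, so here is a minimal sketch, not taken from the commit, of the technique it names: Breeze's axpy(a, x, y) updates y += a * x in place, letting a gradient sum reuse one accumulator instead of allocating a fresh vector per data point. The object name and point values below are invented for illustration.

import breeze.linalg.{DenseVector, axpy}

object InPlaceGradientSketch {
  def main(args: Array[String]): Unit = {
    val points = Seq(DenseVector(1.0, 2.0), DenseVector(3.0, 4.0))

    // Allocating style: every step materializes a new scaled vector and a new sum.
    var sumAlloc = DenseVector.zeros[Double](2)
    points.foreach { x => sumAlloc = sumAlloc + (x * 0.5) }

    // In-place style: axpy(a, x, y) computes y += a * x with no new allocations.
    val sumInPlace = DenseVector.zeros[Double](2)
    points.foreach { x => axpy(0.5, x, sumInPlace) }

    assert(sumAlloc == sumInPlace)
    println(sumInPlace) // DenseVector(2.0, 3.0)
  }
}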

File tree

mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

2 files changed: +60 −4 lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala

Lines changed: 57 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.optimization
 
+import breeze.linalg.{axpy => brzAxpy}
+
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
 /**
@@ -33,6 +35,19 @@ abstract class Gradient extends Serializable {
    * @return (gradient: Vector, loss: Double)
    */
   def compute(data: Vector, label: Double, weights: Vector): (Vector, Double)
+
+  /**
+   * Compute the gradient and loss given the features of a single data point,
+   * add the gradient to a provided vector to avoid creating new objects, and return loss.
+   *
+   * @param data features for one data point
+   * @param label label for this data point
+   * @param weights weights/coefficients corresponding to features
+   * @param gradientAddTo gradient will be added to this vector
+   *
+   * @return loss
+   */
+  def compute(data: Vector, label: Double, weights: Vector, gradientAddTo: Vector): Double
 }
 
 /**
@@ -55,6 +70,21 @@ class LogisticGradient extends Gradient {
 
     (Vectors.fromBreeze(gradient), loss)
   }
+
+  override def compute(data: Vector, label: Double, weights: Vector, gradientAddTo: Vector): Double = {
+    val brzData = data.toBreeze
+    val brzWeights = weights.toBreeze
+    val margin: Double = -1.0 * brzWeights.dot(brzData)
+    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
+
+    brzAxpy(gradientMultiplier, brzData, gradientAddTo.toBreeze)
+
+    if (label > 0) {
+      math.log(1 + math.exp(margin))
+    } else {
+      math.log(1 + math.exp(margin)) - margin
+    }
+  }
 }
 
 /**
@@ -73,6 +103,16 @@ class LeastSquaresGradient extends Gradient {
 
     (Vectors.fromBreeze(gradient), loss)
   }
+
+  override def compute(data: Vector, label: Double, weights: Vector, gradientAddTo: Vector): Double = {
+    val brzData = data.toBreeze
+    val brzWeights = weights.toBreeze
+    val diff = brzWeights.dot(brzData) - label
+
+    brzAxpy(2.0 * diff, brzData, gradientAddTo.toBreeze)
+
+    diff * diff
+  }
 }
 
 /**
@@ -96,4 +136,21 @@ class HingeGradient extends Gradient {
       (Vectors.dense(new Array[Double](weights.size)), 0.0)
     }
   }
+
+  override def compute(data: Vector, label: Double, weights: Vector, gradientAddTo: Vector): Double = {
+    val brzData = data.toBreeze
+    val brzWeights = weights.toBreeze
+    val dotProduct = brzWeights.dot(brzData)
+
+    // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
+    // Therefore the gradient is -(2y - 1)*x
+    val labelScaled = 2 * label - 1.0
+
+    if (1.0 > labelScaled * dotProduct) {
+      brzAxpy(-labelScaled, brzData, gradientAddTo.toBreeze)
+      1.0 - labelScaled * dotProduct
+    } else {
+      0.0
+    }
+  }
 }
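
The diff above only defines the new overload; here is a hedged usage sketch, not part of the commit, showing how a caller might accumulate gradients across several points. The object name and data are invented. A dense accumulator works because toBreeze on a dense MLlib vector wraps the same underlying array, so the brzAxpy inside compute mutates the accumulator directly.

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization.LeastSquaresGradient

object GradientUsageSketch {
  def main(args: Array[String]): Unit = {
    val gradient = new LeastSquaresGradient
    val weights = Vectors.dense(0.5, -0.5)
    // (label, features) pairs; values are made up for illustration
    val points: Seq[(Double, Vector)] = Seq(
      (1.0, Vectors.dense(1.0, 2.0)),
      (0.0, Vectors.dense(3.0, 4.0)))

    // One dense accumulator for the whole pass; each call adds into it in place
    val cumGradient = Vectors.dense(new Array[Double](weights.size))
    var lossSum = 0.0
    for ((label, features) <- points) {
      lossSum += gradient.compute(features, label, weights, cumGradient)
    }
    println(s"gradient = $cumGradient, loss = $lossSum")
  }
}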

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 3 additions & 4 deletions
@@ -162,13 +162,12 @@ object GradientDescent extends Logging {
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
         .aggregate((BDV.zeros[Double](weights.size), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val (g, l) = gradient.compute(features, label, weights)
-            (grad += g.toBreeze, loss + l)
+            val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+            (grad, loss + l)
           },
           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
             (grad1 += grad2, loss1 + loss2)
-          }
-        )
+          })
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
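
Again not from the commit: a small consistency check one could run to confirm that the in-place path matches the allocating path, which is the invariant the GradientDescent change above relies on. Names and inputs are invented; only the two compute overloads defined in Gradient.scala are assumed.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.LogisticGradient

object InPlaceConsistencyCheck {
  def main(args: Array[String]): Unit = {
    val gradient = new LogisticGradient
    val features = Vectors.dense(1.0, -2.0, 0.5)
    val weights = Vectors.dense(0.1, 0.2, 0.3)
    val label = 1.0

    // Allocating overload: returns a fresh (gradient, loss) pair.
    val (g, l1) = gradient.compute(features, label, weights)

    // In-place overload: adds into an existing vector and returns only the loss.
    val acc = Vectors.dense(new Array[Double](weights.size))
    val l2 = gradient.compute(features, label, weights, acc)

    assert(math.abs(l1 - l2) < 1e-12, "losses should agree")
    assert(g.toArray.zip(acc.toArray).forall { case (a, b) => math.abs(a - b) < 1e-12 },
      "gradients should agree")
  }
}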

0 commit comments
