1717
1818package org .apache .spark .mllib .optimization
1919
20+ import breeze .linalg .{axpy => brzAxpy }
21+
2022import org .apache .spark .mllib .linalg .{Vectors , Vector }
2123
2224/**
@@ -33,6 +35,19 @@ abstract class Gradient extends Serializable {
3335 * @return (gradient: Vector, loss: Double)
3436 */
3537 def compute (data : Vector , label : Double , weights : Vector ): (Vector , Double )
38+
39+ /**
40+ * Compute the gradient and loss given the features of a single data point, add the gradient to a provided vector to
41+ * avoid creating new objects, and return loss.
42+ *
43+ * @param data features for one data point
44+ * @param label label for this data point
45+ * @param weights weights/coefficients corresponding to features
46+ * @param gradientAddTo gradient will be added to this vector
47+ *
48+ * @return (gradient: Vector, loss: Double)
49+ */
50+ def compute (data : Vector , label : Double , weights : Vector , gradientAddTo : Vector ): Double
3651}
3752
3853/**
@@ -55,6 +70,21 @@ class LogisticGradient extends Gradient {
5570
5671 (Vectors .fromBreeze(gradient), loss)
5772 }
73+
74+ override def compute (data : Vector , label : Double , weights : Vector , gradientAddTo : Vector ): Double = {
75+ val brzData = data.toBreeze
76+ val brzWeights = weights.toBreeze
77+ val margin : Double = - 1.0 * brzWeights.dot(brzData)
78+ val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
79+
80+ brzAxpy(gradientMultiplier, brzData, gradientAddTo.toBreeze)
81+
82+ if (label > 0 ) {
83+ math.log(1 + math.exp(margin))
84+ } else {
85+ math.log(1 + math.exp(margin)) - margin
86+ }
87+ }
5888}
5989
6090/**
@@ -73,6 +103,16 @@ class LeastSquaresGradient extends Gradient {
73103
74104 (Vectors .fromBreeze(gradient), loss)
75105 }
106+
107+ override def compute (data : Vector , label : Double , weights : Vector , gradientAddTo : Vector ): Double = {
108+ val brzData = data.toBreeze
109+ val brzWeights = weights.toBreeze
110+ val diff = brzWeights.dot(brzData) - label
111+
112+ brzAxpy(2.0 * diff, brzData, gradientAddTo.toBreeze)
113+
114+ diff * diff
115+ }
76116}
77117
78118/**
@@ -96,4 +136,21 @@ class HingeGradient extends Gradient {
96136 (Vectors .dense(new Array [Double ](weights.size)), 0.0 )
97137 }
98138 }
139+
140+ override def compute (data : Vector , label : Double , weights : Vector , gradientAddTo : Vector ): Double = {
141+ val brzData = data.toBreeze
142+ val brzWeights = weights.toBreeze
143+ val dotProduct = brzWeights.dot(brzData)
144+
145+ // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x)))
146+ // Therefore the gradient is -(2y - 1)*x
147+ val labelScaled = 2 * label - 1.0
148+
149+ if (1.0 > labelScaled * dotProduct) {
150+ brzAxpy(- labelScaled, brzData, gradientAddTo.toBreeze)
151+ 1.0 - labelScaled * dotProduct
152+ } else {
153+ 0.0
154+ }
155+ }
99156}
0 commit comments