
Initialized the regVal for first iteration in SGD optimizer #40


Closed · wants to merge 1 commit
@@ -149,7 +149,13 @@ object GradientDescent extends Logging {

// Initialize weights as a column vector
var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
var regVal = 0.0

/**
 * For the first iteration, regVal will be initialized as the sum of
 * the squares of the weights if it's the L2 updater; for the L1 updater,
 * the same logic is followed.
 */
var regVal = updater.compute(
Contributor

Just a style nit here, since @mengxr asked me to chime in.

It would be better if you just put weights and the rest on the same line, e.g.

var regVal = updater.compute(
  weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2

Member Author

Cool. Changed as you suggested.

weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2

for (i <- 1 to numIterations) {
// Sample a subset (fraction miniBatchFraction) of the total data
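The trick behind this initialization can be sketched without jblas. The helper below is a hypothetical, plain-`Array` stand-in for `SquaredL2Updater.compute` (illustrative only, not the PR's code): calling it with a zero gradient, step size 0, and `iter = 1` leaves the weights unchanged and returns the initial regVal.

```scala
// Hypothetical simplified stand-in for SquaredL2Updater.compute, using plain
// Arrays instead of jblas DoubleMatrix. Returns (newWeights, regVal).
def l2Compute(weightsOld: Array[Double], gradient: Array[Double],
              stepSize: Double, iter: Int, regParam: Double): (Array[Double], Double) = {
  val thisIterStepSize = stepSize / math.sqrt(iter)
  val newWeights = weightsOld.zip(gradient).map { case (w, g) =>
    w * (1.0 - thisIterStepSize * regParam) - thisIterStepSize * g
  }
  val norm2 = math.sqrt(newWeights.map(w => w * w).sum)
  (newWeights, 0.5 * regParam * norm2 * norm2)
}

// First-iteration regVal: zero gradient and stepSize = 0 make
// thisIterStepSize = 0, so the weights pass through untouched and
// the second tuple element is 0.5 * regParam * ||w||^2.
val weights = Array(1.0, 0.5)
val (unchanged, regVal) = l2Compute(weights, Array(0.0, 0.0), 0, 1, 2.0)
// regVal = 0.5 * 2.0 * (1.0 * 1.0 + 0.5 * 0.5) = 1.25
```

This mirrors why the PR can take `._2` of the updater's result as the initial regVal.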
@@ -111,6 +111,8 @@ class SquaredL2Updater extends Updater {
val step = gradient.mul(thisIterStepSize)
// add up both updates from the gradient of the loss (= step) as well as
// the gradient of the regularizer (= regParam * weightsOld)
// w' = w - thisIterStepSize * (gradient + regParam * w)
// w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step)
(newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
}
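The two comment lines added above assert an algebraic identity. A scalar spot-check (with arbitrary, purely illustrative values) confirms the factored form used by `newWeights` matches the direct gradient step:

```scala
// Check that (1 - eta * regParam) * w - eta * g
// equals the direct form w - eta * (g + regParam * w).
// The values below are arbitrary, for illustration only.
val w = 0.8
val g = -0.3
val eta = 0.1        // plays the role of thisIterStepSize
val regParam = 2.0
val direct   = w - eta * (g + regParam * w)
val factored = (1.0 - eta * regParam) * w - eta * g
// both come out to approximately 0.67
```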
@@ -104,4 +104,45 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs }
assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8)
}

test("Test the loss and gradient of first iteration with regularization.") {

val gradient = new LogisticGradient()
val updater = new SquaredL2Updater()

// Add an extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
val data = testData.map { case LabeledPoint(label, features) =>
label -> Array(1.0, features: _*)
}

val dataRDD = sc.parallelize(data, 2).cache()

// Prepare non-zero weights
val initialWeightsWithIntercept = Array(1.0, 0.5)

val regParam0 = 0
val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
dataRDD, gradient, updater, 1, 1, regParam0, 1.0, initialWeightsWithIntercept)

val regParam1 = 1
val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept)

def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
}

assert(compareDouble(
loss1(0),
loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
math.pow(initialWeightsWithIntercept(1), 2)) / 2),
"""For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""")

assert(
compareDouble(newWeights1(0), newWeights0(0) - initialWeightsWithIntercept(0)) &&
compareDouble(newWeights1(1), newWeights0(1) - initialWeightsWithIntercept(1)),
"The difference between newWeights with/without regularization " +
"should be initialWeightsWithIntercept.")
}
}
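Why the second assertion should hold can be traced on a single scalar weight. The numbers below are illustrative, mirroring the test's settings (stepSize = 1, one iteration, regParam = 1):

```scala
// With stepSize = 1 and iter = 1, thisIterStepSize = 1, so the L2 update is
//   newWeights1 = (1 - regParam) * w - grad
// and with regParam = 1 this becomes (w - grad) - w = newWeights0 - w.
val w = 0.5          // one initial weight (illustrative)
val grad = 0.2       // some gradient value (illustrative)
val newW0 = w - grad                  // regParam = 0: plain SGD step
val newW1 = (1.0 - 1.0) * w - grad    // regParam = 1: regularized step
// newW1 - newW0 == -w, i.e. newWeights1 = newWeights0 - initialWeights
```

This is exactly the per-component relationship the test checks with `compareDouble`.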