[SPARK-8601][ML] Add an option to disable standardization for linear regression #7875
Changes from all commits
@@ -45,7 +45,7 @@ import org.apache.spark.util.StatCounter
  */
 private[regression] trait LinearRegressionParams extends PredictorParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
-  with HasFitIntercept
+  with HasFitIntercept with HasStandardization

 /**
  * :: Experimental ::
@@ -84,6 +84,18 @@ class LinearRegression(override val uid: String)
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)

+  /**
+   * Whether to standardize the training features before fitting the model.
+   * The coefficients of models will be always returned on the original scale,
+   * so it will be transparent for users. Note that with/without standardization,
+   * the models should be always converged to the same solution when no regularization
+   * is applied. In R's GLMNET package, the default behavior is true as well.
+   * Default is true.
+   * @group setParam
+   */
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
   /**
    * Set the ElasticNet mixing parameter.
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
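For context, a minimal usage sketch of the new flag (not part of this diff; `training` is an assumed DataFrame with the usual `label`/`features` columns):

```scala
import org.apache.spark.ml.regression.LinearRegression

// `training` is a hypothetical DataFrame of (label: Double, features: Vector).
// With no regularization, both settings converge to the same solution; the
// flag only changes the penalty when regParam > 0.
val lr = new LinearRegression()
  .setRegParam(0.1)
  .setElasticNetParam(0.5)
  .setStandardization(false) // default is true, matching R's glmnet

val model = lr.fit(training)
println(s"weights: ${model.weights}, intercept: ${model.intercept}")
```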
@@ -165,12 +177,24 @@ class LinearRegression(override val uid: String)
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam

     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
-      featuresStd, featuresMean, effectiveL2RegParam)
+      $(standardization), featuresStd, featuresMean, effectiveL2RegParam)

     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
+      def effectiveL1RegFun = (index: Int) => {
+        if ($(standardization)) {
+          effectiveL1RegParam
+        } else {
+          // If `standardization` is false, we still standardize the data
+          // to improve the rate of convergence; as a result, we have to
+          // perform this reverse standardization by penalizing each component
+          // differently to get effectively the same objective function when
+          // the training dataset is not standardized.
+          if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+        }
+      }
+      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
     }

     val initialWeights = Vectors.zeros(numFeatures)
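To make the reverse standardization concrete: training always happens in standardized space, where a raw weight w_j corresponds (up to the common label-scaling factor) to w_j / sigma_j on the original feature scale, so dividing the L1 penalty by `featuresStd(index)` reproduces the objective one would get on unscaled features. A standalone sketch of that per-index penalty (plain Scala with illustrative names, not the PR's code):

```scala
// Builds the (Int => Double) per-feature L1 penalty that OWLQN accepts,
// using plain arguments instead of Spark params.
def l1RegFun(
    standardization: Boolean,
    regParam: Double,
    featuresStd: Array[Double]): Int => Double =
  (index: Int) =>
    if (standardization) {
      // Penalize every standardized coefficient equally.
      regParam
    } else if (featuresStd(index) != 0.0) {
      // Scaling the penalty by 1 / sigma_j makes it equivalent to an L1
      // penalty on the original (unstandardized) scale.
      regParam / featuresStd(index)
    } else {
      // Constant feature: nothing to penalize.
      0.0
    }

// Example: sigma = (2.0, 0.0, 0.5) gives penalties (0.05, 0.0, 0.2) for regParam = 0.1.
val regFun = l1RegFun(standardization = false, regParam = 0.1, Array(2.0, 0.0, 0.5))
```

The zero-variance guard mirrors the `featuresStd(index) != 0.0` check in the diff: a constant column carries no information, so it receives no penalty.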
@@ -456,6 +480,7 @@ class LinearRegressionSummary private[regression] (
  * @param weights The weights/coefficients corresponding to the features.
  * @param labelStd The standard deviation value of the label.
  * @param labelMean The mean value of the label.
+ * @param fitIntercept Whether to fit an intercept term.
  * @param featuresStd The standard deviation values of the features.
  * @param featuresMean The mean values of the features.
  */

Inline review comments on the added @param fitIntercept line:

missing standardization param

ha. this is confusing. standardization param is in different class.

oh damn my bad :p
@@ -568,6 +593,7 @@ private class LeastSquaresCostFun(
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
+    standardization: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {

@@ -584,14 +610,38 @@ private class LeastSquaresCostFun(
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
       })

-    // regVal is the sum of weight squares for L2 regularization
-    val norm = brzNorm(weights, 2.0)
-    val regVal = 0.5 * effectiveL2regParam * norm * norm
-
-    val loss = leastSquaresAggregator.loss + regVal
-    val gradient = leastSquaresAggregator.gradient
-    axpy(effectiveL2regParam, w, gradient)
+    val totalGradientArray = leastSquaresAggregator.gradient.toArray
+
+    val regVal = if (effectiveL2regParam == 0.0) {
+      0.0
+    } else {
+      var sum = 0.0
+      w.foreachActive { (index, value) =>
+        // The following code will compute the loss of the regularization; also
+        // the gradient of the regularization, and add back to totalGradientArray.
+        sum += {
+          if (standardization) {
+            totalGradientArray(index) += effectiveL2regParam * value
+            value * value
+          } else {
+            if (featuresStd(index) != 0.0) {
+              // If `standardization` is false, we still standardize the data
+              // to improve the rate of convergence; as a result, we have to
+              // perform this reverse standardization by penalizing each component
+              // differently to get effectively the same objective function when
+              // the training dataset is not standardized.
+              val temp = value / (featuresStd(index) * featuresStd(index))
+              totalGradientArray(index) += effectiveL2regParam * temp
+              value * temp
+            } else {
+              0.0
+            }
+          }
+        }
+      }
+      0.5 * effectiveL2regParam * sum
+    }

-    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray))
   }
 }
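The same trick drives the L2 branch above, just with a quadratic penalty: the original-scale coefficient is (up to the common label scaling) value / sigma_j, so its square is value^2 / sigma_j^2, which is exactly `value * temp`, and the matching gradient contribution is `effectiveL2regParam * temp`. A self-contained sketch of the idea, using plain arrays rather than the PR's Vector/foreachActive (names illustrative):

```scala
// Returns the L2 regularization value and accumulates its gradient into `gradient`.
// When standardization is off, each component is scaled by 1 / sigma_j^2 so that the
// penalty and its gradient match an L2 penalty on the original, unscaled features.
def addL2Reg(
    weights: Array[Double],
    gradient: Array[Double],
    featuresStd: Array[Double],
    regParam: Double,
    standardization: Boolean): Double = {
  if (regParam == 0.0) {
    0.0
  } else {
    var sum = 0.0
    var j = 0
    while (j < weights.length) {
      val w = weights(j)
      if (standardization) {
        gradient(j) += regParam * w
        sum += w * w
      } else if (featuresStd(j) != 0.0) {
        // Reverse standardization: penalize w_j / sigma_j^2 per component.
        val temp = w / (featuresStd(j) * featuresStd(j))
        gradient(j) += regParam * temp
        sum += w * temp
      }
      j += 1
    }
    0.5 * regParam * sum
  }
}
```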
Minor note: since this is a new API, consider mentioning that the old default did have standardization turned on, so people don't expect differences? (nit)

The old behavior is `true` (standardization on). Should we mention it, since it's just a new feature?

Yeah, I think it would be nice to just mention that while this is a new flag, the old default effectively had this flag on, so don't worry (also maybe add a @since tag?)

Sounds great. How about you create a JIRA for both this PR and the one for LOR? Thanks.

If an update does not change the default behavior, I don't think there's a great need to notify the user. (Unnecessary documentation detracts from important docs.) But I do like the idea of saying somewhere that the default behavior matches that of glmnet (in the Scala doc of setStandardization).
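If that suggestion is adopted, the Scaladoc might grow roughly along these lines (the extra sentence and the @since version are assumptions for illustration, not part of the merged diff):

```scala
/**
 * Whether to standardize the training features before fitting the model.
 * The coefficients of models will be always returned on the original scale,
 * so it will be transparent for users. Note that with/without standardization,
 * the models should be always converged to the same solution when no regularization
 * is applied. In R's GLMNET package, the default behavior is true as well.
 * Default is true, which matches the previous (implicit) behavior, so existing
 * pipelines see no change.
 * @group setParam
 * @since 1.5.0  // assumed release; fill in whichever version this lands in
 */
def setStandardization(value: Boolean): this.type = set(standardization, value)
```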