
Commit 15eb86c

sethah authored and dbtsai committed
[SPARK-18456][ML][FOLLOWUP] Use matrix abstraction for coefficients in LogisticRegression training
## What changes were proposed in this pull request?

This is a follow up to some of the discussion [here](#15593). During LogisticRegression training, we store the coefficients combined with intercepts as a flat vector, but a more natural abstraction is a matrix. Here, we refactor the code to use a matrix where possible, which makes the code more readable and greatly simplifies the indexing.

Note: We do not use a Breeze matrix for the cost function, as was mentioned in the linked PR. This is because LBFGS/OWLQN require an implicit `MutableInnerProductModule[DenseMatrix[Double], Double]`, which is not natively defined in Breeze; we would need to extend Breeze in Spark to define it ourselves. Also, we do not modify `regParamL1Fun`, because OWLQN in Breeze requires a `MutableEnumeratedCoordinateField[(Int, Int), DenseVector[Double]]` (since we still use a dense vector for coefficients). Here again we would have to extend Breeze inside Spark.

## How was this patch tested?

This is internal code refactoring - the existing unit tests continue to pass, which shows that the change did not break anything. No added functionality in this patch.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #15893 from sethah/logreg_refactor.

(cherry picked from commit 856e004)
Signed-off-by: DB Tsai <dbtsai@dbtsai.com>
1 parent 15ad3a3 commit 15eb86c
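For context, here is a minimal sketch (not part of the patch; toy values) of the indexing this refactor removes. With coefficients stored column major in a flat array, reading the entry for a given class and feature requires hand-computed flat indices; the matrix abstraction makes it a plain two-dimensional lookup.

```scala
import org.apache.spark.ml.linalg.DenseMatrix

// Toy setup: numClasses = 3, numFeatures = 2, fitIntercept = true.
val numCoefficientSets = 3        // one coefficient set per class
val numFeaturesPlusIntercept = 3  // 2 features + 1 intercept column
// Column-major values: beta_11, beta_21, beta_31, beta_12, beta_22, beta_32,
// intercept_1, intercept_2, intercept_3 (the flat layout used during training).
val values = Array(0.11, 0.21, 0.31, 0.12, 0.22, 0.32, 1.0, 2.0, 3.0)
val m = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, values)

val (classIndex, featureIndex) = (1, 0)
// Before: compute the column-major flat index by hand.
val flatIndex = featureIndex * numCoefficientSets + classIndex
// After: index the matrix directly; the two agree (both read beta_21).
assert(values(flatIndex) == m(classIndex, featureIndex))
```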

File tree: 1 file changed (+53, −62)

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 53 additions & 62 deletions
@@ -463,16 +463,11 @@ class LogisticRegression @Since("1.2.0") (
       }
 
       /*
-         The coefficients are laid out in column major order during training. e.g. for
-         `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is:
-
-         Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2,
-           intercept_3)
-
-         where beta_jk corresponds to the coefficient for class `j` and feature `k`.
+         The coefficients are laid out in column major order during training. Here we initialize
+         a column major matrix of initial coefficients.
       */
-      val initialCoefficientsWithIntercept =
-        Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept)
+      val initialCoefWithInterceptMatrix =
+        Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept)
 
       val initialModelIsValid = optInitialModel match {
         case Some(_initialModel) =>
@@ -491,18 +486,15 @@ class LogisticRegression @Since("1.2.0") (
       }
 
       if (initialModelIsValid) {
-        val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray
         val providedCoef = optInitialModel.get.coefficientMatrix
-        providedCoef.foreachActive { (row, col, value) =>
-          // convert matrix to column major for training
-          val flatIndex = col * numCoefficientSets + row
+        providedCoef.foreachActive { (classIndex, featureIndex, value) =>
           // We need to scale the coefficients since they will be trained in the scaled space
-          initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col)
+          initialCoefWithInterceptMatrix.update(classIndex, featureIndex,
+            value * featuresStd(featureIndex))
         }
         if ($(fitIntercept)) {
-          optInitialModel.get.interceptVector.foreachActive { (index, value) =>
-            val coefIndex = numCoefficientSets * numFeatures + index
-            initialCoefWithInterceptArray(coefIndex) = value
+          optInitialModel.get.interceptVector.foreachActive { (classIndex, value) =>
+            initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value)
           }
         }
       } else if ($(fitIntercept) && isMultinomial) {
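The `value * featuresStd(featureIndex)` above is the entry point of the scaled-space convention: initial coefficients are multiplied by each feature's standard deviation going into training and divided by it coming out (see the hunk at line 586 below). A toy round trip, purely illustrative:

```scala
// Sketch: coefficients enter training scaled by sigma_j and leave divided by
// sigma_j, so a valid initial model round-trips unchanged.
val featuresStd = Array(2.0, 0.5)
val original = Array(1.0, -3.0)  // coefficients in the original feature space
val scaled = original.zip(featuresStd).map { case (b, s) => b * s }
val restored = scaled.zip(featuresStd).map { case (b, s) => b / s }
assert(original.sameElements(restored))
```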
@@ -532,8 +524,7 @@ class LogisticRegression @Since("1.2.0") (
         val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing
         val rawMean = rawIntercepts.sum / rawIntercepts.length
         rawIntercepts.indices.foreach { i =>
-          initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) =
-            rawIntercepts(i) - rawMean
+          initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean)
         }
       } else if ($(fitIntercept)) {
         /*
@@ -549,12 +540,12 @@ class LogisticRegression @Since("1.2.0") (
              b = \log{P(1) / P(0)} = \log{count_1 / count_0}
            }}}
          */
-        initialCoefficientsWithIntercept.toArray(numFeatures) = math.log(
-          histogram(1) / histogram(0))
+        initialCoefWithInterceptMatrix.update(0, numFeatures,
+          math.log(histogram(1) / histogram(0)))
       }
 
       val states = optimizer.iterations(new CachedDiffFunction(costFun),
-        initialCoefficientsWithIntercept.asBreeze.toDenseVector)
+        new BDV[Double](initialCoefWithInterceptMatrix.toArray))
 
       /*
          Note that in Logistic Regression, the objective history (loss + regularization)
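Note that the optimizer still consumes a flat Breeze vector: since `DenseMatrix` stores its values column major, `toArray` reproduces exactly the flat layout the old code allocated with `Vectors.zeros`, and the solution can be wrapped back into a matrix without any index arithmetic. A small sketch of that round trip (toy sizes, public `ml.linalg` API only):

```scala
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices}
import breeze.linalg.{DenseVector => BDV}

val (rows, cols) = (3, 3)  // numCoefficientSets x numFeaturesPlusIntercept
val init = Matrices.zeros(rows, cols)
// Into the optimizer: flatten column major, as the new code does.
val flat = new BDV[Double](init.toArray)
assert(flat.length == rows * cols)
// Out of the optimizer: reshape the solution back into a column-major matrix.
val solution = new DenseMatrix(rows, cols, flat.toArray)
assert(solution(2, 1) == flat(1 * rows + 2))  // entry (row 2, col 1)
```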
@@ -586,15 +577,24 @@ class LogisticRegression @Since("1.2.0") (
          Note that the intercept in scaled space and original space is the same;
          as a result, no scaling is needed.
        */
-      val rawCoefficients = state.x.toArray.clone()
-      val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i =>
-        val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures
-        val featureIndex = i % numFeatures
-        if (featuresStd(featureIndex) != 0.0) {
-          rawCoefficients(colMajorIndex) / featuresStd(featureIndex)
-        } else {
-          0.0
+      val allCoefficients = state.x.toArray.clone()
+      val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept,
+        allCoefficients)
+      val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures,
+        new Array[Double](numCoefficientSets * numFeatures), isTransposed = true)
+      val interceptVec = if ($(fitIntercept) || !isMultinomial) {
+        Vectors.zeros(numCoefficientSets)
+      } else {
+        Vectors.sparse(numCoefficientSets, Seq())
+      }
+      // separate intercepts and coefficients from the combined matrix
+      allCoefMatrix.foreachActive { (classIndex, featureIndex, value) =>
+        val isIntercept = $(fitIntercept) && (featureIndex == numFeatures)
+        if (!isIntercept && featuresStd(featureIndex) != 0.0) {
+          denseCoefficientMatrix.update(classIndex, featureIndex,
+            value / featuresStd(featureIndex))
         }
+        if (isIntercept) interceptVec.toArray(classIndex) = value
       }
 
       if ($(regParam) == 0.0 && isMultinomial) {
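The separation loop above replaces the old hand-computed offsets: with `fitIntercept`, the combined matrix simply has one extra column, and an entry is an intercept exactly when `featureIndex == numFeatures`. A toy illustration of the split (hypothetical values):

```scala
import org.apache.spark.ml.linalg.DenseMatrix

val (numClasses, numFeatures) = (2, 2)
// Column-major combined matrix; the last column holds the intercepts.
val combined = new DenseMatrix(numClasses, numFeatures + 1,
  Array(0.5, -0.5, 1.5, -1.5, 0.1, -0.1))
for (classIndex <- 0 until numClasses; featureIndex <- 0 to numFeatures) {
  val isIntercept = featureIndex == numFeatures
  val value = combined(classIndex, featureIndex)
  if (isIntercept) println(s"intercept($classIndex) = $value")
  else println(s"coef($classIndex, $featureIndex) = $value")
}
```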
@@ -607,17 +607,16 @@ class LogisticRegression @Since("1.2.0") (
            Friedman, et al. "Regularization Paths for Generalized Linear Models via
            Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
          */
-        val coefficientMean = coefficientArray.sum / coefficientArray.length
-        coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean }
+        val denseValues = denseCoefficientMatrix.values
+        val coefficientMean = denseValues.sum / denseValues.length
+        denseCoefficientMatrix.update(_ - coefficientMean)
       }
 
-      val denseCoefficientMatrix =
-        new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true)
       // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471
       val compressedCoefficientMatrix = if (isMultinomial) {
         denseCoefficientMatrix
       } else {
-        val compressedVector = Vectors.dense(coefficientArray).compressed
+        val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed
         compressedVector match {
           case dv: DenseVector => denseCoefficientMatrix
           case sv: SparseVector =>
@@ -626,25 +625,13 @@ class LogisticRegression @Since("1.2.0") (
         }
       }
 
-      val interceptsArray: Array[Double] = if ($(fitIntercept)) {
-        Array.tabulate(numCoefficientSets) { i =>
-          val coefIndex = numFeatures * numCoefficientSets + i
-          rawCoefficients(coefIndex)
-        }
-      } else {
-        Array.empty[Double]
-      }
-      val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) {
-        // The intercepts are never regularized, so we always center the mean.
-        val interceptMean = interceptsArray.sum / numClasses
-        interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean }
-        Vectors.dense(interceptsArray)
-      } else if (interceptsArray.length == 1) {
-        Vectors.dense(interceptsArray)
-      } else {
-        Vectors.sparse(numCoefficientSets, Seq())
+      // center the intercepts when using multinomial algorithm
+      if ($(fitIntercept) && isMultinomial) {
+        val interceptArray = interceptVec.toArray
+        val interceptMean = interceptArray.sum / interceptArray.length
+        (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }
       }
-      (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result())
+      (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result())
     }
   }
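Centering is safe because the multinomial model is only identified up to a constant shift of the intercepts: subtracting the mean leaves every predicted probability unchanged and just picks a canonical representative. A quick numeric check (plain Scala, not from the patch):

```scala
// softmax is invariant to subtracting the same constant from every input,
// so centering the intercepts does not change predicted class probabilities.
def softmax(z: Array[Double]): Array[Double] = {
  val exps = z.map(math.exp)
  val total = exps.sum
  exps.map(_ / total)
}
val intercepts = Array(1.0, 2.0, 3.0)
val centered = intercepts.map(_ - intercepts.sum / intercepts.length)
val (p, q) = (softmax(intercepts), softmax(centered))
assert(p.zip(q).forall { case (a, b) => math.abs(a - b) < 1e-12 })
```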

@@ -1424,6 +1411,7 @@ private class LogisticAggregator(
   private val numFeatures = bcFeaturesStd.value.length
   private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
   private val coefficientSize = bcCoefficients.value.size
+  private val numCoefficientSets = if (multinomial) numClasses else 1
   if (multinomial) {
     require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " +
       s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
@@ -1633,12 +1621,12 @@ private class LogisticAggregator(
     lossSum / weightSum
   }
 
-  def gradient: Vector = {
+  def gradient: Matrix = {
     require(weightSum > 0.0, s"The effective number of instances should be " +
       s"greater than 0.0, but $weightSum.")
     val result = Vectors.dense(gradientSumArray.clone())
     scal(1.0 / weightSum, result)
-    result
+    new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray)
   }
 }

@@ -1664,6 +1652,7 @@ private class LogisticCostFun(
     val featuresStd = bcFeaturesStd.value
     val numFeatures = featuresStd.length
     val numCoefficientSets = if (multinomial) numClasses else 1
+    val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
 
     val logisticAggregator = {
       val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
@@ -1675,32 +1664,34 @@ private class LogisticCostFun(
       )(seqOp, combOp, aggregationDepth)
     }
 
-    val totalGradientArray = logisticAggregator.gradient.toArray
+    val totalGradientMatrix = logisticAggregator.gradient
+    val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray)
     // regVal is the sum of coefficients squares excluding intercept for L2 regularization.
    val regVal = if (regParamL2 == 0.0) {
       0.0
     } else {
       var sum = 0.0
-      coeffs.foreachActive { case (index, value) =>
+      coefMatrix.foreachActive { case (classIndex, featureIndex, value) =>
         // We do not apply regularization to the intercepts
-        val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures
+        val isIntercept = fitIntercept && (featureIndex == numFeatures)
         if (!isIntercept) {
           // The following code will compute the loss of the regularization; also
           // the gradient of the regularization, and add back to totalGradientArray.
           sum += {
             if (standardization) {
-              totalGradientArray(index) += regParamL2 * value
+              val gradValue = totalGradientMatrix(classIndex, featureIndex)
+              totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value)
               value * value
             } else {
-              val featureIndex = index / numCoefficientSets
               if (featuresStd(featureIndex) != 0.0) {
                 // If `standardization` is false, we still standardize the data
                 // to improve the rate of convergence; as a result, we have to
                 // perform this reverse standardization by penalizing each component
                 // differently to get effectively the same objective function when
                 // the training dataset is not standardized.
                 val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
-                totalGradientArray(index) += regParamL2 * temp
+                val gradValue = totalGradientMatrix(classIndex, featureIndex)
+                totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp)
                 value * temp
               } else {
                 0.0
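The reverse standardization term above follows from a short derivation. Training uses scaled features, so the coefficient the optimizer sees relates to the original one by \(\tilde{\beta}_j = \beta_j \sigma_j\). Penalizing the original coefficients therefore means (a sketch, using the usual \(\lambda/2\) convention for the regularization constant):

```latex
% L2 penalty on the original coefficients, rewritten in the scaled space:
\frac{\lambda}{2} \beta_j^2
  = \frac{\lambda}{2} \frac{\tilde{\beta}_j^2}{\sigma_j^2},
\qquad
\frac{\partial}{\partial \tilde{\beta}_j}
  \left( \frac{\lambda}{2} \frac{\tilde{\beta}_j^2}{\sigma_j^2} \right)
  = \lambda \, \frac{\tilde{\beta}_j}{\sigma_j^2}
```

This is exactly the `temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))` term added to the gradient and the `value * temp` accumulated into the loss.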
@@ -1713,6 +1704,6 @@ private class LogisticCostFun(
     }
     bcCoeffs.destroy(blocking = false)
 
-    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
+    (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray))
   }
 }
