[SPARK-7685][ML] Apply weights to different samples in Logistic Regression #7884

Closed · wants to merge 26 commits
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen {
Some("\"rawPrediction\"")),
ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" +
" probabilities. Note: Not all models output well-calibrated probability estimates!" +
" These probabilities should be treated as confidences, not precise probabilities.",
" These probabilities should be treated as confidences, not precise probabilities",
Some("\"probability\"")),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
@@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen {
"options may be added later.",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
" before fitting the model.", Some("true")),
" before fitting the model", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
" For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
" For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty",
isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
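Why the trailing periods are dropped above: the shared-params code generator appends its own period when it renders each description into the `Param for ...` Scaladoc, so a description that already ends in `.` comes out as `..` in the generated file (visible in the old lines of the next file). A minimal sketch of that interaction, with `genParamDoc` as a hypothetical stand-in for the real template code:

```scala
object PeriodSketch {
  // Hypothetical stand-in for the codegen template; the real generator
  // likewise appends "." after the description it is given.
  def genParamDoc(description: String): String = s"Param for $description."

  def main(args: Array[String]): Unit = {
    // Old description ends with "." -> doubled period in the generated doc.
    println(genParamDoc("whether to standardize the training features before fitting the model."))
    // New description has no trailing "." -> exactly one period.
    println(genParamDoc("whether to standardize the training features before fitting the model"))
  }
}
```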
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params {
private[ml] trait HasProbabilityCol extends Params {

/**
* Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities..
* Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
* @group param
*/
final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities")

setDefault(probabilityCol, "probability")

@@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params {
private[ml] trait HasStandardization extends Params {

/**
* Param for whether to standardize the training features before fitting the model..
* Param for whether to standardize the training features before fitting the model.
* @group param
*/
final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model")

setDefault(standardization, true)

@@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params {
private[ml] trait HasElasticNetParam extends Params {

/**
* Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
* Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* @group param
*/
final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1))

/** @group getParam */
final def getElasticNetParam: Double = $(elasticNetParam)
mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* :: DeveloperApi ::
* MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
* variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
* variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector
* format in an online fashion.
*
* Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
* the corresponding joint dataset.
*
* A numerically stable algorithm is implemented to compute sample mean and variance:
* A numerically stable algorithm is implemented to compute the mean and variance of instances:
* Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
* Zero elements (including explicit zero values) are skipped when calling add(),
* to have time complexity O(nnz) instead of O(n) for each column.
*
* For weighted instances, the unbiased estimation of variance is defined by the reliability
* weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]].
*/
@Since("1.1.0")
@DeveloperApi
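Written out, the reliability-weights estimator referenced in the new Scaladoc: with weights $w_i$, the weighted mean and the unbiased variance estimate are

$$\mu_w = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad \hat\sigma^2 = \frac{\sum_i w_i (x_i - \mu_w)^2}{V_1 - V_2/V_1}, \qquad V_1 = \sum_i w_i,\; V_2 = \sum_i w_i^2.$$

In the fields added below, $V_1$ is `weightSum` and $V_2$ is `weightSquareSum`; with unit weights the denominator reduces to $n - 1$, the ordinary Bessel correction.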
@@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currM2: Array[Double] = _
private var currL1: Array[Double] = _
private var totalCnt: Long = 0
private var weightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var nnz: Array[Double] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
@@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
* @return This MultivariateOnlineSummarizer object.
*/
@Since("1.1.0")
def add(sample: Vector): this.type = {
def add(sample: Vector): this.type = add(sample, 1.0)

private[spark] def add(instance: Vector, weight: Double): this.type = {
require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0")
if (weight == 0.0) return this

if (n == 0) {
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size
require(instance.size > 0, s"Vector should have dimension larger than zero.")
n = instance.size

currMean = Array.ofDim[Double](n)
currM2n = Array.ofDim[Double](n)
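A quick sketch of the new entry points: the public `add(sample)` keeps its old behavior by delegating with weight 1.0, while the weighted overload stays `private[spark]` (so the weighted calls are shown as comments; behavior follows the `require` and early return above):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val summarizer = new MultivariateOnlineSummarizer()
summarizer.add(Vectors.dense(1.0, 2.0))          // public API, weight 1.0
// Visible only inside org.apache.spark:
// summarizer.add(Vectors.dense(3.0, 4.0), 0.5)  // contributes with weight 0.5
// summarizer.add(Vectors.dense(5.0, 6.0), 0.0)  // no-op: zero weight returns early
// summarizer.add(Vectors.dense(7.0, 8.0), -1.0) // throws IllegalArgumentException
```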
@@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMin = Array.fill[Double](n)(Double.MaxValue)
}

require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
require(n == instance.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${instance.size}.")

val localCurrMean = currMean
val localCurrM2n = currM2n
@@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localNnz = nnz
val localCurrMax = currMax
val localCurrMin = currMin
sample.foreachActive { (index, value) =>
instance.foreachActive { (index, value) =>
if (value != 0.0) {
if (localCurrMax(index) < value) {
localCurrMax(index) = value
@@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S

val prevMean = localCurrMean(index)
val diff = value - prevMean
localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0)
localCurrM2n(index) += (value - localCurrMean(index)) * diff
localCurrM2(index) += value * value
localCurrL1(index) += math.abs(value)
localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight)
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
localCurrM2(index) += weight * value * value
localCurrL1(index) += weight * math.abs(value)

localNnz(index) += 1.0
localNnz(index) += weight
}
}

weightSum += weight
weightSquareSum += weight * weight
totalCnt += 1
this
}
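The per-coordinate update above is the weighted form of Welford's online algorithm. A self-contained scalar sketch that mirrors the `localCurrMean` / `localCurrM2n` lines, with `wSum` playing the role of `localNnz(index)`:

```scala
// Scalar analogue of the weighted update applied per nonzero coordinate.
final class WeightedWelford {
  private var mean = 0.0
  private var m2n = 0.0   // weighted sum of squared deviations (currM2n analogue)
  private var wSum = 0.0  // accumulated weight (localNnz analogue)

  def add(x: Double, w: Double): Unit = {
    val diff = x - mean
    mean += w * diff / (wSum + w)   // localCurrMean(index) update
    m2n += w * (x - mean) * diff    // localCurrM2n(index) update, uses post-update mean
    wSum += w
  }

  def currentMean: Double = mean
}
```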
@@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
def merge(other: MultivariateOnlineSummarizer): this.type = {
if (this.totalCnt != 0 && other.totalCnt != 0) {
if (this.weightSum != 0.0 && other.weightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
weightSum += other.weightSum
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = nnz(i)
@@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
nnz(i) = totalNnz
i += 1
}
} else if (totalCnt == 0 && other.totalCnt != 0) {
} else if (weightSum == 0.0 && other.weightSum != 0.0) {
this.n = other.n
this.currMean = other.currMean.clone()
this.currM2n = other.currM2n.clone()
this.currM2 = other.currM2.clone()
this.currL1 = other.currL1.clone()
this.totalCnt = other.totalCnt
this.weightSum = other.weightSum
this.weightSquareSum = other.weightSquareSum
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
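The per-coordinate merge loop is collapsed in this hunk, but it follows the standard pairwise-combination rule for parallel variance (Chan et al.): for two partitions with per-coordinate weight totals $W_a, W_b$, means $\mu_a, \mu_b$, and deviation sums $M_a, M_b$,

$$\delta = \mu_b - \mu_a, \qquad \mu_{ab} = \mu_a + \delta\,\frac{W_b}{W_a + W_b}, \qquad M_{ab} = M_a + M_b + \delta^2\,\frac{W_a W_b}{W_a + W_b},$$

which is why only the scalar totals `totalCnt`, `weightSum`, and `weightSquareSum` need new accumulation here; the per-coordinate arrays already carry everything else.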
@@ -158,38 +174,37 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
realMean(i) = currMean(i) * (nnz(i) / weightSum)
i += 1
}
Vectors.dense(realMean)
}

/**
* Sample variance of each dimension.
* Unbiased estimate of sample variance of each dimension.
*
*/
@Since("1.1.0")
override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

val realVariance = Array.ofDim[Double](n)

val denominator = totalCnt - 1.0
val denominator = weightSum - (weightSquareSum / weightSum)

// Sample variance is computed; if the denominator is not greater than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= denominator
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
(weightSum - nnz(i)) / weightSum) / denominator
i += 1
}
}
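A sanity check on the new denominator: with $n$ unit weights, `weightSum` and `weightSquareSum` are both $n$, so the expression reduces to the old `totalCnt - 1.0`:

```scala
// Unit weights recover the classic Bessel denominator n - 1.
val n = 5.0
val weightSum = n                  // sum of n unit weights
val weightSquareSum = n            // sum of n squared unit weights
val denominator = weightSum - (weightSquareSum / weightSum)
assert(denominator == n - 1.0)
```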
@@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

Vectors.dense(nnz)
}
@@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def max: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.dense(currMax)
@@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def min: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.dense(currMin)
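The `nnz(i) < weightSum` guard in `max` and `min` handles implicit zeros: when a coordinate received less total weight than the summarizer as a whole, at least one instance had a zero there, so 0.0 must be considered a candidate extremum. For example:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val s = new MultivariateOnlineSummarizer()
s.add(Vectors.dense(1.0, 2.0))
s.add(Vectors.sparse(2, Seq((0, 3.0))))  // coordinate 1 is an implicit zero
println(s.min)  // [1.0, 0.0]: nnz(1) = 1.0 < weightSum = 2.0, so currMin(1) is pulled down to 0.0
println(s.max)  // [3.0, 2.0]
```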
@@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

val realMagnitude = Array.ofDim[Double](n)

@@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
require(weightSum > 0, s"Nothing has been added to this summarizer.")

Vectors.dense(currL1)
}