
[SPARK-18710][ML] Add offset in GLM #16699


Closed. Wants to merge 25 commits.
25 commits
- 3bf2718 add trait offset (actuaryzhang, Jan 24, 2017)
- 0e240eb add offset setter (actuaryzhang, Jan 24, 2017)
- 9c41453 implement offset in GLM (actuaryzhang, Jan 25, 2017)
- 7823f8a add test for glm with offset (actuaryzhang, Jan 25, 2017)
- a1f5695 minor cleanup (actuaryzhang, Jan 25, 2017)
- d071b95 add doc for GLRInstance (actuaryzhang, Jan 25, 2017)
- d2afcb0 remove offset from shared param (actuaryzhang, Jan 25, 2017)
- 9eca1a6 fix style issue (actuaryzhang, Jan 25, 2017)
- d44974c rename to OffsetInstance and add param check (actuaryzhang, Jan 25, 2017)
- 9c320ee create separate instance definition when initializing (actuaryzhang, Jan 26, 2017)
- e183c08 fix style in test (actuaryzhang, Jan 26, 2017)
- 58f93af resolve conflict (actuaryzhang, Jan 27, 2017)
- da4174a add test for tweedie (actuaryzhang, Jan 27, 2017)
- 52bc32b cast offset and add in instrumentation (actuaryzhang, Jan 28, 2017)
- 59e10f7 update var name (actuaryzhang, Jan 30, 2017)
- 1d41bdd add test for intercept only (actuaryzhang, Feb 8, 2017)
- fb372ad update test (actuaryzhang, Feb 8, 2017)
- 2bc3ae7 pull and merge (actuaryzhang, Feb 8, 2017)
- afb4643 implement null dev for offset model (actuaryzhang, Feb 9, 2017)
- fc64d32 fix null deviance calculation and add tests (actuaryzhang, Feb 10, 2017)
- 90d68a6 allow missing offset in prediction (actuaryzhang, Feb 14, 2017)
- e95c25b clean up (actuaryzhang, Feb 14, 2017)
- 4b336be Merge branch 'master' of https://github.com/apache/spark into offset (actuaryzhang, Feb 14, 2017)
- 1e47a11 address comments (actuaryzhang, Jun 27, 2017)
- db0ac93 address comments (actuaryzhang, Jun 29, 2017)
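For background on what these commits add: in a GLM, an offset is a term added to the linear predictor with its coefficient fixed at 1, so that g(E[y]) = x'beta + intercept + offset. The classic use is log(exposure) as the offset in a Poisson model with log link, which turns a count model into a rate model. A minimal plain-Scala sketch of how the offset enters prediction; the function name and array types here are illustrative, not the PR's API:

```scala
// Sketch only: in a GLM with link g, the offset enters the linear
// predictor with a fixed coefficient of 1:
//   g(E[y]) = x . beta + intercept + offset
def linearPredictor(
    features: Array[Double],
    coefficients: Array[Double],
    intercept: Double,
    offset: Double): Double =
  features.zip(coefficients).map { case (x, b) => x * b }.sum + intercept + offset

// Poisson with log link: using log(exposure) as the offset models a rate.
val eta = linearPredictor(Array(1.0, 2.0), Array(0.5, -0.25), 0.1, math.log(2.0))
val mu  = math.exp(eta) // expected count, scaled by an exposure of 2
```

Because the offset's coefficient is not estimated, doubling the exposure simply doubles the expected count without changing the fitted rate.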
mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala: 21 additions, 0 deletions

@@ -27,3 +27,24 @@ import org.apache.spark.ml.linalg.Vector
* @param features The vector of features for this data point.
*/
private[ml] case class Instance(label: Double, weight: Double, features: Vector)

/**
* Case class that represents an instance of data point with
* label, weight, offset and features.
* This is mainly used in GeneralizedLinearRegression currently.
*
* @param label Label for this data point.
* @param weight The weight of this instance.
* @param offset The offset used for this data point.
* @param features The vector of features for this data point.
*/
private[ml] case class OffsetInstance(
label: Double,
weight: Double,
offset: Double,
features: Vector) {

/** Converts to an [[Instance]] object by leaving out the offset. */
def toInstance: Instance = Instance(label, weight, features)

}
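The added OffsetInstance simply carries an extra Double alongside the usual label, weight, and features, and toInstance drops it. A self-contained sketch of the same shape, using Array[Double] as a stand-in for Spark's Vector (the real classes are private[ml]):

```scala
// Plain-Scala stand-ins mirroring the case classes in the diff
// (the real ones are private[ml] and use org.apache.spark.ml.linalg.Vector).
case class Instance(label: Double, weight: Double, features: Array[Double])

case class OffsetInstance(
    label: Double,
    weight: Double,
    offset: Double,
    features: Array[Double]) {

  /** Converts to an Instance by leaving out the offset. */
  def toInstance: Instance = Instance(label, weight, features)
}

val oi   = OffsetInstance(label = 3.0, weight = 1.0, offset = math.log(2.0), features = Array(1.0, 0.5))
val inst = oi.toInstance // offset is dropped; label, weight, features survive
```

Keeping the offset out of the shared params and in a dedicated instance type (as the commit history shows) keeps the common Instance class unchanged for all other estimators.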
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim

import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD

@@ -43,7 +43,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel(
* find M-estimator in robust regression and other optimization problems.
*
* @param initialModel the initial guess model.
- * @param reweightFunc the reweight function which is used to update offsets and weights
+ * @param reweightFunc the reweight function which is used to update working labels and weights
* at each iteration.
* @param fitIntercept whether to fit intercept.
* @param regParam L2 regularization parameter used by WLS.
@@ -57,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel(
*/
private[ml] class IterativelyReweightedLeastSquares(
val initialModel: WeightedLeastSquaresModel,
-    val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double),
+    val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double),
val fitIntercept: Boolean,
val regParam: Double,
val maxIter: Int,
val tol: Double) extends Logging with Serializable {

-  def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {
+  def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = {

var converged = false
var iter = 0
@@ -75,10 +75,10 @@ private[ml] class IterativelyReweightedLeastSquares(

oldModel = model

-      // Update offsets and weights using reweightFunc
+      // Update working labels and weights using reweightFunc
       val newInstances = instances.map { instance =>
-        val (newOffset, newWeight) = reweightFunc(instance, oldModel)
-        Instance(newOffset, newWeight, instance.features)
+        val (newLabel, newWeight) = reweightFunc(instance, oldModel)
+        Instance(newLabel, newWeight, instance.features)
}

// Estimate new model
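With reweightFunc now taking an OffsetInstance, the working label can be computed net of the offset, so the inner WLS solve refits only x'beta. As a hedged illustration (not the PR's code), here is what one reweight step looks like for a Poisson family with log link, where g'(mu) = 1/mu and Var(mu) = mu, again using plain-Scala stand-ins for the private Spark classes:

```scala
// Stand-in for the OffsetInstance in the diff (illustrative types).
case class OffsetInstance(label: Double, weight: Double, offset: Double, features: Array[Double])

// Stand-in for WeightedLeastSquaresModel.predict: x . beta + intercept, no offset.
def predict(coefficients: Array[Double], intercept: Double, features: Array[Double]): Double =
  features.zip(coefficients).map { case (x, b) => x * b }.sum + intercept

// One IRLS reweight step for the Poisson family with log link.
def reweightPoissonLog(
    instance: OffsetInstance,
    coefficients: Array[Double],
    intercept: Double): (Double, Double) = {
  val eta = predict(coefficients, intercept, instance.features) + instance.offset
  val mu  = math.exp(eta) // inverse of the log link
  // Working label excludes the offset, since WLS refits only x . beta:
  //   z = (eta - offset) + (y - mu) * g'(mu), with g'(mu) = 1/mu here.
  val newLabel = (eta - instance.offset) + (instance.label - mu) / mu
  // w / (g'(mu)^2 * Var(mu)) = w / ((1/mu)^2 * mu) = w * mu for Poisson/log.
  val newWeight = instance.weight * mu
  (newLabel, newWeight)
}

val (z, w) = reweightPoissonLog(OffsetInstance(3.0, 1.0, math.log(2.0), Array(1.0)), Array(0.0), 0.0)
```

Spark's actual implementation expresses the same step generically through the family's variance function and the link's derivative, which is why the renamed "working labels" wording in the diff is more accurate than "offsets".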
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim

import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
