Skip to content

Commit 091cbc3

Browse files
committed
[SPARK-9612][ML] Add instance weight support for GBTs
### What changes were proposed in this pull request? add weight support for GBTs by sampling data before passing it to trees and then passing weights to trees in summary: 1, add setters of `minWeightFractionPerNode` & `weightCol` 2, update input types in private methods from `RDD[LabeledPoint]` to `RDD[Instance]`: `DecisionTreeRegressor.train`, `GradientBoostedTrees.run`, `GradientBoostedTrees.runWithValidation`, `GradientBoostedTrees.computeInitialPredictionAndError`, `GradientBoostedTrees.computeError`, `GradientBoostedTrees.evaluateEachIteration`, `GradientBoostedTrees.boost`, `GradientBoostedTrees.updatePredictionError` 3, add new private method `GradientBoostedTrees.computeError(data, predError)` to compute average error, since original `predError.values.mean()` do not take weights into account. 4, add new tests ### Why are the changes needed? GBTs should support sample weights like other algs ### Does this PR introduce any user-facing change? yes, new setters are added ### How was this patch tested? existing & added testsuites Closes #25926 from zhengruifeng/gbt_add_weight. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
1 parent 1474ed0 commit 091cbc3

File tree

11 files changed

+261
-166
lines changed

11 files changed

+261
-166
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ private[spark] trait ClassifierParams
5353
val validateInstance = (instance: Instance) => {
5454
val label = instance.label
5555
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
56-
s" dataset with invalid label $label. Labels must be integers in range" +
56+
s" dataset with invalid label $label. Labels must be integers in range" +
5757
s" [0, $numClasses).")
5858
}
5959
extractInstances(dataset, validateInstance)

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
2323

2424
import org.apache.spark.annotation.Since
2525
import org.apache.spark.internal.Logging
26-
import org.apache.spark.ml.feature.LabeledPoint
26+
import org.apache.spark.ml.feature.Instance
2727
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
2828
import org.apache.spark.ml.param.ParamMap
2929
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
3434
import org.apache.spark.ml.util.Instrumentation.instrumented
3535
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
3636
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
37-
import org.apache.spark.sql.{DataFrame, Dataset, Row}
37+
import org.apache.spark.sql.{DataFrame, Dataset}
3838
import org.apache.spark.sql.functions._
3939

4040
/**
@@ -79,6 +79,10 @@ class GBTClassifier @Since("1.4.0") (
7979
@Since("1.4.0")
8080
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
8181

82+
/** @group setParam */
83+
@Since("3.0.0")
84+
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
85+
8286
/** @group setParam */
8387
@Since("1.4.0")
8488
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -152,36 +156,34 @@ class GBTClassifier @Since("1.4.0") (
152156
set(validationIndicatorCol, value)
153157
}
154158

159+
/**
160+
* Sets the value of param [[weightCol]].
161+
* If this is not set or empty, we treat all instance weights as 1.0.
162+
* By default the weightCol is not set, so all instances have weight 1.0.
163+
*
164+
* @group setParam
165+
*/
166+
@Since("3.0.0")
167+
def setWeightCol(value: String): this.type = set(weightCol, value)
168+
155169
override protected def train(
156170
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
157-
val categoricalFeatures: Map[Int, Int] =
158-
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
159-
160171
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
161172

162-
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
163-
// 2 classes now. This lets us provide a more precise error message.
164-
val convert2LabeledPoint = (dataset: Dataset[_]) => {
165-
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
166-
case Row(label: Double, features: Vector) =>
167-
require(label == 0 || label == 1, s"GBTClassifier was given" +
168-
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
169-
s" GBTClassifier currently only supports binary classification.")
170-
LabeledPoint(label, features)
171-
}
173+
val validateInstance = (instance: Instance) => {
174+
val label = instance.label
175+
require(label == 0 || label == 1, s"GBTClassifier was given" +
176+
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
177+
s" GBTClassifier currently only supports binary classification.")
172178
}
173179

174180
val (trainDataset, validationDataset) = if (withValidation) {
175-
(
176-
convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
177-
convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
178-
)
181+
(extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance),
182+
extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance))
179183
} else {
180-
(convert2LabeledPoint(dataset), null)
184+
(extractInstances(dataset, validateInstance), null)
181185
}
182186

183-
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
184-
185187
val numClasses = 2
186188
if (isDefined(thresholds)) {
187189
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -191,12 +193,14 @@ class GBTClassifier @Since("1.4.0") (
191193

192194
instr.logPipelineStage(this)
193195
instr.logDataset(dataset)
194-
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity,
195-
lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
196-
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
197-
validationIndicatorCol, validationTol)
196+
instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
197+
impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
198+
minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
199+
checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol)
198200
instr.logNumClasses(numClasses)
199201

202+
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
203+
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
200204
val (baseLearners, learnerWeights) = if (withValidation) {
201205
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
202206
$(seed), $(featureSubsetStrategy))
@@ -374,12 +378,9 @@ class GBTClassificationModel private[ml](
374378
*/
375379
@Since("2.4.0")
376380
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
377-
val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
378-
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
379-
}
381+
val data = extractInstances(dataset)
380382
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
381-
OldAlgo.Classification
382-
)
383+
OldAlgo.Classification)
383384
}
384385

385386
@Since("2.0.0")
@@ -423,10 +424,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
423424
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
424425
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
425426

426-
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
427+
val trees = treesData.map {
427428
case (treeMetadata, root) =>
428-
val tree =
429-
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
429+
val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
430430
treeMetadata.getAndSetParams(tree)
431431
tree
432432
}

mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector
2626
* @param weight The weight of this instance.
2727
* @param features The vector of features for this data point.
2828
*/
29-
private[ml] case class Instance(label: Double, weight: Double, features: Vector)
29+
private[spark] case class Instance(label: Double, weight: Double, features: Vector)
3030

3131
/**
3232
* Case class that represents an instance of data point with

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
2323

2424
import org.apache.spark.annotation.Since
2525
import org.apache.spark.ml.{PredictionModel, Predictor}
26-
import org.apache.spark.ml.feature.LabeledPoint
26+
import org.apache.spark.ml.feature.Instance
2727
import org.apache.spark.ml.linalg.Vector
2828
import org.apache.spark.ml.param.ParamMap
2929
import org.apache.spark.ml.tree._
@@ -132,15 +132,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
132132

133133
/** (private[ml]) Train a decision tree on an RDD */
134134
private[ml] def train(
135-
data: RDD[LabeledPoint],
135+
data: RDD[Instance],
136136
oldStrategy: OldStrategy,
137137
featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr =>
138138
instr.logPipelineStage(this)
139139
instr.logDataset(data)
140140
instr.logParams(this, params: _*)
141141

142-
val instances = data.map(_.toInstance)
143-
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
142+
val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
144143
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
145144

146145
trees.head.asInstanceOf[DecisionTreeRegressionModel]

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import org.json4s.JsonDSL._
2424
import org.apache.spark.annotation.Since
2525
import org.apache.spark.internal.Logging
2626
import org.apache.spark.ml.{PredictionModel, Predictor}
27-
import org.apache.spark.ml.feature.LabeledPoint
2827
import org.apache.spark.ml.linalg.Vector
2928
import org.apache.spark.ml.param.ParamMap
3029
import org.apache.spark.ml.tree._
@@ -34,7 +33,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
3433
import org.apache.spark.ml.util.Instrumentation.instrumented
3534
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
3635
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
37-
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
36+
import org.apache.spark.sql.{Column, DataFrame, Dataset}
3837
import org.apache.spark.sql.functions._
3938

4039
/**
@@ -78,6 +77,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
7877
@Since("1.4.0")
7978
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
8079

80+
/** @group setParam */
81+
@Since("3.0.0")
82+
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
83+
8184
/** @group setParam */
8285
@Since("1.4.0")
8386
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -151,29 +154,35 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
151154
set(validationIndicatorCol, value)
152155
}
153156

154-
override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
155-
val categoricalFeatures: Map[Int, Int] =
156-
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
157+
/**
158+
* Sets the value of param [[weightCol]].
159+
* If this is not set or empty, we treat all instance weights as 1.0.
160+
* By default the weightCol is not set, so all instances have weight 1.0.
161+
*
162+
* @group setParam
163+
*/
164+
@Since("3.0.0")
165+
def setWeightCol(value: String): this.type = set(weightCol, value)
157166

167+
override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
158168
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
159169

160170
val (trainDataset, validationDataset) = if (withValidation) {
161-
(
162-
extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))),
163-
extractLabeledPoints(dataset.filter(col($(validationIndicatorCol))))
164-
)
171+
(extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
172+
extractInstances(dataset.filter(col($(validationIndicatorCol)))))
165173
} else {
166-
(extractLabeledPoints(dataset), null)
174+
(extractInstances(dataset), null)
167175
}
168-
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
169176

170177
instr.logPipelineStage(this)
171178
instr.logDataset(dataset)
172-
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, lossType,
173-
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
174-
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
175-
validationIndicatorCol, validationTol)
179+
instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, weightCol, impurity,
180+
lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
181+
minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval,
182+
featureSubsetStrategy, validationIndicatorCol, validationTol)
176183

184+
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
185+
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
177186
val (baseLearners, learnerWeights) = if (withValidation) {
178187
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
179188
$(seed), $(featureSubsetStrategy))
@@ -323,9 +332,7 @@ class GBTRegressionModel private[ml](
323332
*/
324333
@Since("2.4.0")
325334
def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
326-
val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
327-
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
328-
}
335+
val data = extractInstances(dataset)
329336
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
330337
convertToOldLossType(loss), OldAlgo.Regression)
331338
}
@@ -368,10 +375,9 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
368375
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
369376
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
370377

371-
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
378+
val trees = treesData.map {
372379
case (treeMetadata, root) =>
373-
val tree =
374-
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
380+
val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
375381
treeMetadata.getAndSetParams(tree)
376382
tree
377383
}

0 commit comments

Comments
 (0)