[SPARK-9612][ML] Add instance weight support for GBTs #25926


Status: Closed. Wants to merge 9 commits.
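For context, a minimal usage sketch of the API this patch adds (not part of the diff; assumes a Spark build with this patch applied and an active SparkSession `spark`; the data and column names are illustrative):

    import org.apache.spark.ml.classification.GBTClassifier
    import org.apache.spark.ml.linalg.Vectors

    // Toy data: label, per-instance weight, features.
    val train = spark.createDataFrame(Seq(
      (0.0, 1.0, Vectors.dense(0.0, 1.1)),
      (0.0, 2.0, Vectors.dense(0.1, 1.0)),
      (1.0, 3.0, Vectors.dense(2.0, 0.1)),
      (1.0, 1.0, Vectors.dense(1.9, 0.2))
    )).toDF("label", "weight", "features")

    val gbt = new GBTClassifier()
      .setWeightCol("weight")                // new in this patch
      .setMinWeightFractionPerNode(0.05)     // new in this patch
      .setMaxIter(5)

    val model = gbt.fit(train)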
Classifier.scala
@@ -53,7 +53,7 @@ private[spark] trait ClassifierParams
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}
extractInstances(dataset, validateInstance)
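A standalone sketch of the validation-closure pattern above. Instance and extractInstances are Spark-internal, so this uses a local stand-in type; the require logic mirrors the diff:

    case class Inst(label: Double, weight: Double, features: Array[Double])

    def validator(numClasses: Int): Inst => Unit = (instance: Inst) => {
      val label = instance.label
      require(label.toLong == label && label >= 0 && label < numClasses,
        s"Classifier was given dataset with invalid label $label." +
        s" Labels must be integers in range [0, $numClasses).")
    }

    validator(3)(Inst(2.0, 1.0, Array(0.5)))   // passes
    // validator(3)(Inst(3.0, 1.0, Array(0.5))) // throws IllegalArgumentException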
GBTClassifier.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

/**
@@ -79,6 +79,10 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+/** @group setParam */
+@Since("3.0.0")
+def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
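A worked illustration of the new minWeightFractionPerNode param (per its doc in ml trees, the minimum fraction of the total weighted sample count that each child must keep after a split); toy numbers, not from this diff:

    val weights = Seq(1.0, 2.0, 3.0, 1.0)
    val totalWeight = weights.sum                                   // 7.0
    val minWeightFractionPerNode = 0.1
    val minWeightPerChild = minWeightFractionPerNode * totalWeight  // 0.7
    // A candidate split that leaves either child with total instance
    // weight below 0.7 is rejected as invalid.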
@@ -152,36 +156,34 @@ class GBTClassifier @Since("1.4.0") (
set(validationIndicatorCol, value)
}

+/**
+* Sets the value of param [[weightCol]].
+* If this is not set or empty, we treat all instance weights as 1.0.
+* By default the weightCol is not set, so all instances have weight 1.0.
+*
+* @group setParam
+*/
+@Since("3.0.0")
+def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-val categoricalFeatures: Map[Int, Int] =
-MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty

-// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
-// 2 classes now. This lets us provide a more precise error message.
-val convert2LabeledPoint = (dataset: Dataset[_]) => {
[Review comment from a contributor, on the removed convert2LabeledPoint block:]
The error message here was much nicer:

    GBTClassifier currently only supports binary classification.

than the new one in extractInstances. Perhaps it would be nicer to keep this custom error message, or pass some part of the message to the extractInstances method.

-dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-case Row(label: Double, features: Vector) =>
-require(label == 0 || label == 1, s"GBTClassifier was given" +
-s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
-s" GBTClassifier currently only supports binary classification.")
-LabeledPoint(label, features)
-}
+val validateInstance = (instance: Instance) => {
+val label = instance.label
+require(label == 0 || label == 1, s"GBTClassifier was given" +
+s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+s" GBTClassifier currently only supports binary classification.")
+}

val (trainDataset, validationDataset) = if (withValidation) {
-(
-convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
-convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
-)
+(extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance),
+extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance))
} else {
-(convert2LabeledPoint(dataset), null)
+(extractInstances(dataset, validateInstance), null)
}

-val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

val numClasses = 2
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -191,12 +193,14 @@ class GBTClassifier @Since("1.4.0") (

instr.logPipelineStage(this)
instr.logDataset(dataset)
-instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity,
-lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
-validationIndicatorCol, validationTol)
+instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
+impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
+minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
+checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol)
instr.logNumClasses(numClasses)

+val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val (baseLearners, learnerWeights) = if (withValidation) {
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
@@ -374,12 +378,9 @@ class GBTClassificationModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
-val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-case Row(label: Double, features: Vector) => LabeledPoint(label, features)
-}
+val data = extractInstances(dataset)
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
-OldAlgo.Classification
-)
+OldAlgo.Classification)
}

@Since("2.0.0")
@@ -423,10 +424,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]

-val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+val trees = treesData.map {
case (treeMetadata, root) =>
-val tree =
-new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
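A hedged sketch of the usual semantic check for instance weights (a test idea, not code from this diff): with integer-valued weights, training with setWeightCol should roughly match training on data where each row is repeated `weight` times with unit weight. Assumes `spark` and the `train` DataFrame from the sketch near the top:

    import org.apache.spark.ml.classification.GBTClassifier
    import org.apache.spark.sql.functions._

    // Expand each row into `weight` copies, then reset all weights to 1.0.
    val oversampled = train
      .withColumn("rep", explode(sequence(lit(1), col("weight").cast("int"))))
      .drop("rep")
      .withColumn("weight", lit(1.0))

    val weighted = new GBTClassifier().setWeightCol("weight").setSeed(42L).fit(train)
    val repeated = new GBTClassifier().setSeed(42L).fit(oversampled)
    // The two models should make near-identical predictions.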
Instance.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
-private[ml] case class Instance(label: Double, weight: Double, features: Vector)
+private[spark] case class Instance(label: Double, weight: Double, features: Vector)

/**
* Case class that represents an instance of data point with
DecisionTreeRegressor.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -132,15 +132,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S

/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(
-data: RDD[LabeledPoint],
+data: RDD[Instance],
oldStrategy: OldStrategy,
featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(data)
instr.logParams(this, params: _*)

-val instances = data.map(_.toInstance)
-val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeRegressionModel]
GBTRegressor.scala
@@ -24,7 +24,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,7 +33,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._

/**
@@ -78,6 +77,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+/** @group setParam */
+@Since("3.0.0")
+def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -151,29 +154,35 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
set(validationIndicatorCol, value)
}

-override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
-val categoricalFeatures: Map[Int, Int] =
-MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+/**
+* Sets the value of param [[weightCol]].
+* If this is not set or empty, we treat all instance weights as 1.0.
+* By default the weightCol is not set, so all instances have weight 1.0.
+*
+* @group setParam
+*/
+@Since("3.0.0")
+def setWeightCol(value: String): this.type = set(weightCol, value)
+
+override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty

val (trainDataset, validationDataset) = if (withValidation) {
-(
-extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))),
-extractLabeledPoints(dataset.filter(col($(validationIndicatorCol))))
-)
+(extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
+extractInstances(dataset.filter(col($(validationIndicatorCol)))))
} else {
-(extractLabeledPoints(dataset), null)
+(extractInstances(dataset), null)
}
-val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)

instr.logPipelineStage(this)
instr.logDataset(dataset)
-instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, lossType,
-maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
-seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
-validationIndicatorCol, validationTol)
+instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, weightCol, impurity,
+lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
+minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval,
+featureSubsetStrategy, validationIndicatorCol, validationTol)

+val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val (baseLearners, learnerWeights) = if (withValidation) {
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
@@ -323,9 +332,7 @@ class GBTRegressionModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
-val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-case Row(label: Double, features: Vector) => LabeledPoint(label, features)
-}
+val data = extractInstances(dataset)
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
convertToOldLossType(loss), OldAlgo.Regression)
}
@@ -368,10 +375,9 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

-val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+val trees = treesData.map {
case (treeMetadata, root) =>
-val tree =
-new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
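Finally, a sketch of the regressor side, which gains the same two setters, here combined with a validation-indicator column so training goes through the runWithValidation path shown above. Assumes `spark` and an input DataFrame `data` with label, weight, features and a boolean indicator column; names are illustrative:

    import org.apache.spark.ml.regression.GBTRegressor

    val gbtr = new GBTRegressor()
      .setWeightCol("weight")
      .setMinWeightFractionPerNode(0.05)
      .setValidationIndicatorCol("isVal")
      .setMaxIter(20)

    // val model = gbtr.fit(data)  // rows with isVal == true are used only for validation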