
Commit 8602195

manishamde authored and mengxr committed
[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib
Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting, with AdaBoost to follow soon after this PR is accepted. This is based on work done along with hirakendu that was pending due to decision tree optimizations and random forests work.

Ideally, boosting algorithms should work with any base learners. This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.

Here is the task list:

- [x] Gradient boosting support
- [x] Pluggable loss functions
- [x] Stochastic gradient boosting support -- re-uses the BaggedPoint approach from RandomForest.
- [x] Binary classification support
- [x] Support configurable checkpointing -- this approach avoids long lineage chains.
- [x] Create classification and regression APIs
- [x] Weighted ensemble model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting.
- [x] Unit tests

Future work:

+ Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function.
+ BaggedRDD caching -- avoid repeating the feature-to-bin mapping for each tree estimator after the standard API work is completed.

cc: jkbradley hirakendu mengxr etrain atalwalkar chouqin

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>

Closes apache#2607 from manishamde/gbt and squashes the following commits:

991c7b5 [Manish Amde] public api
ff2a796 [Manish Amde] addressing comments
b4c1318 [Manish Amde] removing spaces
8476b6b [Manish Amde] fixing line length
0183cb9 [Manish Amde] fixed naming and formatting issues
1c40c33 [Manish Amde] add newline, removed spaces
e33ab61 [Manish Amde] minor comment
eadbf09 [Manish Amde] parameter renaming
035a2ed [Manish Amde] jkbradley formatting suggestions
9f7359d [Manish Amde] simplified gbt logic and added more tests
49ba107 [Manish Amde] merged from master
eff21fe [Manish Amde] Added gradient boosting tests
3fd0528 [Manish Amde] moved helper methods to new class
a32a5ab [Manish Amde] added test for subsampling without replacement
781542a [Manish Amde] added support for fractional subsampling with replacement
3a18cc1 [Manish Amde] cleaned up api for conversion to bagged point and moved tests to its own test suite
0e81906 [Manish Amde] improving caching unpersisting logic
d971f73 [Manish Amde] moved RF code to use WeightedEnsembleModel class
fee06d3 [Manish Amde] added weighted ensemble model
1b01943 [Manish Amde] add weights for base learners
9bc6e74 [Manish Amde] adding random seed as parameter
d2c8323 [Manish Amde] Merge branch 'master' into gbt
2ae97b7 [Manish Amde] added documentation for the loss classes
9366b8f [Manish Amde] minor: using numTrees instead of trees.size
3b43896 [Manish Amde] added learning rate for prediction
9b2e35e [Manish Amde] Merge branch 'master' into gbt
6a11c02 [manishamde] fixing formatting
823691b [Manish Amde] fixing RF test
1f47941 [Manish Amde] changing access modifier
5b67102 [Manish Amde] shortened parameter list
5ab3796 [Manish Amde] minor reformatting
9155a9d [Manish Amde] consolidated boosting configuration and added public API
631baea [Manish Amde] Merge branch 'master' into gbt
2cb1258 [Manish Amde] public API support
3b8ffc0 [Manish Amde] added documentation
8e10c63 [Manish Amde] modified unpersist strategy
f62bc48 [Manish Amde] added unpersist
bdca43a [Manish Amde] added timing parameters
2fbc9c7 [Manish Amde] fixing binomial classification prediction
6dd4dd8 [Manish Amde] added support for log loss
9af0231 [Manish Amde] classification attempt
62cc000 [Manish Amde] basic checkpointing
4784091 [Manish Amde] formatting
78ed452 [Manish Amde] added newline and fixed if statement
3973dd1 [Manish Amde] minor indicating subsample is double during comparison
aa8fae7 [Manish Amde] minor refactoring
1a8031c [Manish Amde] sampling with replacement
f1c9ef7 [Manish Amde] Merge branch 'master' into gbt
cdceeef [Manish Amde] added documentation
6251fd5 [Manish Amde] modified method name
5538521 [Manish Amde] disable checkpointing for now
0ae1c0a [Manish Amde] basic gradient boosting code from earlier branches
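For orientation, here is a minimal usage sketch of the public API this commit introduces. It is not from the PR itself: the dataset path, the parameter values, and the loss-name string are assumptions (Losses.fromString maps the accepted names), and an existing SparkContext `sc` is assumed; the call mirrors the string-based trainRegressor overload in GradientBoosting.scala further down.

```scala
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.util.MLUtils

// Load a LibSVM-format regression dataset (path is hypothetical);
// `sc` is an existing SparkContext.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_regression_data.txt")

// Weak-learner configuration: shallow regression trees.
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 4)

// Train a boosted ensemble. The loss name is an assumption; see Losses.fromString.
val model = GradientBoosting.trainRegressor(
  input = data,
  numEstimators = 10,
  loss = "leastSquaresError",
  learningRate = 0.1,
  subsamplingRate = 1.0,
  numClassesForClassification = 2, // ignored for regression
  categoricalFeaturesInfo = Map.empty[Int, Int],
  weakLearnerParams = treeStrategy)

// Evaluate training MSE, mirroring DecisionTreeRunner.meanSquaredError below.
val mse = data.map { p =>
  val err = model.predict(p.features) - p.label
  err * err
}.mean()
```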
1 parent e07fb6a commit 8602195

20 files changed: +1331 −267 lines

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -317,7 +317,7 @@ object DecisionTreeRunner {
   /**
    * Calculates the mean squared error for regression.
    */
-  private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+  private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
     data.map { y =>
       val err = tree.predict(y.features) - y.label
       err * err

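The only change here is swapping RandomForestModel for the new WeightedEnsembleModel. For intuition, here is a hypothetical helper (not the class's actual code) showing what the Sum combining strategy implies for prediction: the ensemble output is the weight-scaled sum of the base trees' raw predictions. In the gradient boosting code below, the first tree gets weight 1.0 and every later tree gets weight learningRate.

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Hypothetical sketch of Sum combining: a weighted sum of base-learner predictions.
def predictBySumming(
    trees: Array[DecisionTreeModel],
    weights: Array[Double],
    features: Vector): Double = {
  trees.zip(weights).map { case (tree, w) => w * tree.predict(features) }.sum
}
```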
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
     // Note: random seed will not be used since numTrees = 1.
     val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
     val rfModel = rf.train(input)
-    rfModel.trees(0)
+    rfModel.weakHypotheses(0)
   }

 }
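Since RandomForestModel has been folded into WeightedEnsembleModel, the ensemble's members are exposed as weakHypotheses rather than trees, and a single decision tree is now trained as a one-tree forest. A hypothetical before/after for a caller holding an ensemble model:

```scala
// Before this commit: members of a RandomForestModel were addressed as `trees`.
// val firstTree = rfModel.trees(0)

// After: the same ensemble is a WeightedEnsembleModel, so the base
// learners are addressed as weak hypotheses.
val firstTree: DecisionTreeModel = rfModel.weakHypotheses(0)
```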
mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala

Lines changed: 314 additions & 0 deletions

@@ -0,0 +1,314 @@ (new file; all lines added)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
import org.apache.spark.Logging
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum

/**
 * :: Experimental ::
 * A class that implements gradient boosting for regression and binary classification problems.
 * @param boostingStrategy Parameters for the gradient boosting algorithm
 */
@Experimental
class GradientBoosting (
    private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {

  /**
   * Method to train a gradient boosting model.
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
    val algo = boostingStrategy.algo
    algo match {
      case Regression => GradientBoosting.boost(input, boostingStrategy)
      case Classification =>
        // Map binary labels {0, 1} to {-1, +1} for the loss computations.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoosting.boost(remappedInput, boostingStrategy)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }

}

object GradientBoosting extends Logging {

  /**
   * Method to train a gradient boosting model.
   *
   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
   * is recommended to clearly specify regression.
   * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
   * is recommended to clearly specify classification.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
   *              For regression, labels are real numbers.
   * @param boostingStrategy Configuration options for the boosting algorithm.
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def train(
      input: RDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    new GradientBoosting(boostingStrategy).train(input)
  }

  /**
   * Method to train a gradient boosting classification model.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              Labels should take values {0, 1, ..., numClasses-1}.
   * @param boostingStrategy Configuration options for the boosting algorithm.
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def trainClassifier(
      input: RDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    val algo = boostingStrategy.algo
    require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
    new GradientBoosting(boostingStrategy).train(input)
  }

  /**
   * Method to train a gradient boosting regression model.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              Labels are real numbers.
   * @param boostingStrategy Configuration options for the boosting algorithm.
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def trainRegressor(
      input: RDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    val algo = boostingStrategy.algo
    require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
    new GradientBoosting(boostingStrategy).train(input)
  }

  /**
   * Method to train a gradient boosting binary classification model.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              Labels should take values {0, 1, ..., numClasses-1}.
   * @param numEstimators Number of estimators used in boosting stages. In other words,
   *                      the number of boosting iterations performed.
   * @param loss Loss function used for minimization during gradient boosting.
   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
   *                     learning rate should be in the interval (0, 1].
   * @param subsamplingRate Fraction of the training data used for learning each decision tree.
   * @param numClassesForClassification Number of classes for classification.
   *                                    (Ignored for regression.)
   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
   *                                the number of discrete values they take. For example, an entry
   *                                (n -> k) implies feature n is categorical with k categories
   *                                0, 1, 2, ..., k-1. Note that features are zero-indexed.
   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
   *                          supported.)
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def trainClassifier(
      input: RDD[LabeledPoint],
      numEstimators: Int,
      loss: String,
      learningRate: Double,
      subsamplingRate: Double,
      numClassesForClassification: Int,
      categoricalFeaturesInfo: Map[Int, Int],
      weakLearnerParams: Strategy): WeightedEnsembleModel = {
    val lossType = Losses.fromString(loss)
    val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
      weakLearnerParams)
    new GradientBoosting(boostingStrategy).train(input)
  }

  /**
   * Method to train a gradient boosting regression model.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              Labels are real numbers.
   * @param numEstimators Number of estimators used in boosting stages. In other words,
   *                      the number of boosting iterations performed.
   * @param loss Loss function used for minimization during gradient boosting.
   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
   *                     learning rate should be in the interval (0, 1].
   * @param subsamplingRate Fraction of the training data used for learning each decision tree.
   * @param numClassesForClassification Number of classes for classification.
   *                                    (Ignored for regression.)
   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
   *                                the number of discrete values they take. For example, an entry
   *                                (n -> k) implies feature n is categorical with k categories
   *                                0, 1, 2, ..., k-1. Note that features are zero-indexed.
   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
   *                          supported.)
   * @return WeightedEnsembleModel that can be used for prediction
   */
  def trainRegressor(
      input: RDD[LabeledPoint],
      numEstimators: Int,
      loss: String,
      learningRate: Double,
      subsamplingRate: Double,
      numClassesForClassification: Int,
      categoricalFeaturesInfo: Map[Int, Int],
      weakLearnerParams: Strategy): WeightedEnsembleModel = {
    val lossType = Losses.fromString(loss)
    val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
      weakLearnerParams)
    new GradientBoosting(boostingStrategy).train(input)
  }

  /**
   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
   */
  def trainClassifier(
      input: RDD[LabeledPoint],
      numEstimators: Int,
      loss: String,
      learningRate: Double,
      subsamplingRate: Double,
      numClassesForClassification: Int,
      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
      weakLearnerParams: Strategy): WeightedEnsembleModel = {
    trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
      numClassesForClassification,
      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
      weakLearnerParams)
  }

  /**
   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
   */
  def trainRegressor(
      input: RDD[LabeledPoint],
      numEstimators: Int,
      loss: String,
      learningRate: Double,
      subsamplingRate: Double,
      numClassesForClassification: Int,
      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
      weakLearnerParams: Strategy): WeightedEnsembleModel = {
    trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
      numClassesForClassification,
      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
      weakLearnerParams)
  }

  /**
   * Internal method for performing regression using trees as base learners.
   * @param input Training dataset.
   * @param boostingStrategy Boosting parameters.
   * @return WeightedEnsembleModel that can be used for prediction.
   */
  private def boost(
      input: RDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {

    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    // Initialize gradient boosting parameters
    val numEstimators = boostingStrategy.numEstimators
    val baseLearners = new Array[DecisionTreeModel](numEstimators)
    val baseLearnerWeights = new Array[Double](numEstimators)
    val loss = boostingStrategy.loss
    val learningRate = boostingStrategy.learningRate
    val strategy = boostingStrategy.weakLearnerParams

    // Cache input
    input.persist(StorageLevel.MEMORY_AND_DISK)

    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")
    var data = input

    // 1. Initialize the ensemble with the first tree (weight 1.0).
    timer.start("building tree 0")
    val firstTreeModel = new DecisionTree(strategy).train(data)
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = 1.0
    // Note: A model of type regression is used since we require raw prediction
    val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
      Sum)
    logDebug("error of gbt = " + loss.computeError(startingModel, input))
    timer.stop("building tree 0")

    // Pseudo-residuals for the second iteration
    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
      point.features))

    // 2. Iteratively fit the remaining trees to the pseudo-residuals.
    var m = 1
    while (m < numEstimators) {
      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      val model = new DecisionTree(strategy).train(data)
      timer.stop(s"building tree $m")
      // Create partial model
      baseLearners(m) = model
      baseLearnerWeights(m) = learningRate
      // Note: A model of type regression is used since we require raw prediction
      val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
        baseLearnerWeights.slice(0, m + 1), Regression, Sum)
      logDebug("error of gbt = " + loss.computeError(partialModel, input))
      // Update data with pseudo-residuals
      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
        point.features))
      m += 1
    }

    timer.stop("total")

    logInfo("Internal timing for GradientBoosting:")
    logInfo(s"$timer")

    // 3. Output classifier
    new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)

  }

}
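In equation form, the training loop in boost() is standard functional gradient descent with shrinkage. A compact restatement (the notation is mine, not from the source):

```latex
\begin{aligned}
F_0(x)    &= h_0(x) && \text{first tree, ensemble weight } 1.0 \\
r_i^{(m)} &= -\left.\frac{\partial L\big(y_i, F(x_i)\big)}{\partial F(x_i)}\right|_{F = F_{m-1}}
          && \text{pseudo-residuals via } \texttt{loss.gradient} \\
h_m       &= \text{regression tree fit to } \{(x_i, r_i^{(m)})\}_i \\
F_m(x)    &= F_{m-1}(x) + \eta\, h_m(x)
          && \eta = \texttt{learningRate},\ \text{combined with } \texttt{Sum}
\end{aligned}
```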
