
Commit 15cacc8

[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.

1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
2. GradientBoosting -> GradientBoostedTrees
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
4. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
6. Rename class `train` method to `run` because it hides the static methods with the same name in Java. Deprecated `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
9. doc updates

manishamde jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs
1 parent e216ffa commit 15cacc8
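For quick reference, the renamed entry point looks like this from spark-shell. This is a minimal sketch based on the updated examples in the diffs below; `sc` is the shell's SparkContext, and the sample data path and parameter values are illustrative, not part of this commit.

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.util.MLUtils

    // Load an RDD[LabeledPoint] (sample path bundled with Spark; substitute your own data).
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

    // Tree-specific settings such as numClasses and maxDepth now live on treeStrategy.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.numClassesForClassification = 2
    boostingStrategy.treeStrategy.maxDepth = 5

    // train replaces the removed trainClassifier/trainRegressor and returns a GradientBoostedTreesModel.
    val model = GradientBoostedTrees.train(data, boostingStrategy)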

File tree

20 files changed: +382 −437 lines changed


examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java renamed to examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java

Lines changed: 9 additions & 9 deletions
@@ -27,18 +27,18 @@
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
 import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
 import org.apache.spark.mllib.util.MLUtils;
 
 /**
  * Classification and regression using gradient-boosted decision trees.
  */
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {
 
   private static void usage() {
-    System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+    System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
       " <Classification/Regression>");
     System.exit(-1);
   }
@@ -55,7 +55,7 @@ public static void main(String[] args) {
     if (args.length > 2) {
       usage();
     }
-    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
     JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
     JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -64,7 +64,7 @@ public static void main(String[] args) {
     // Note: All features are treated as continuous.
     BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
     boostingStrategy.setNumIterations(10);
-    boostingStrategy.weakLearnerParams().setMaxDepth(5);
+    boostingStrategy.treeStrategy().setMaxDepth(5);
 
     if (algo.equals("Classification")) {
       // Compute the number of classes from the data.
@@ -73,10 +73,10 @@ public static void main(String[] args) {
           return p.label();
         }
       }).countByValue().size();
-      boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+      boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
 
       // Train a GradientBoosting model for classification.
-      final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+      final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
 
       // Evaluate model on training instances and compute training error
       JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public static void main(String[] args) {
       System.out.println("Learned classification tree model:\n" + model);
     } else if (algo.equals("Regression")) {
       // Train a GradientBoosting model for classification.
-      final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+      final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
 
       // Evaluate model on training instances and compute training error
       JavaPairRDD<Double, Double> predictionAndLabel =

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 4 additions & 14 deletions
@@ -22,11 +22,11 @@ import scopt.OptionParser
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -349,24 +349,14 @@ object DecisionTreeRunner {
     sc.stop()
   }
 
-  /**
-   * Calculates the mean squared error for regression.
-   */
-  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = tree.predict(y.features) - y.label
-      err * err
-    }.mean()
-  }
-
   /**
    * Calculates the mean squared error for regression.
    */
   private[mllib] def meanSquaredError(
-      tree: WeightedEnsembleModel,
+      model: { def predict(features: Vector): Double },
       data: RDD[LabeledPoint]): Double = {
     data.map { y =>
-      val err = tree.predict(y.features) - y.label
+      val err = model.predict(y.features) - y.label
       err * err
     }.mean()
   }
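The refactored helper above accepts any model exposing predict(Vector): Double through a Scala structural type, so one implementation now serves both single decision trees and the ensemble models introduced by this commit. A standalone restatement of that idea follows; the MseSketch wrapper object and the scala.language.reflectiveCalls import are additions for this sketch, not part of the commit.

    import scala.language.reflectiveCalls

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    object MseSketch {
      // The structural type matches anything with predict(features: Vector): Double,
      // so DecisionTreeModel and the new tree ensemble models can share this helper.
      def meanSquaredError(
          model: { def predict(features: Vector): Double },
          data: RDD[LabeledPoint]): Double = {
        data.map { p =>
          val err = model.predict(p.features) - p.label
          err * err
        }.mean()
      }
    }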

examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala renamed to examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala

Lines changed: 9 additions & 9 deletions
@@ -21,21 +21,21 @@ import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.GradientBoostedTrees
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
 import org.apache.spark.util.Utils
 
 /**
  * An example runner for Gradient Boosting using decision trees as weak learners. Run with
  * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  *
  * Note: This script treats all features as real-valued (not categorical).
  * To include categorical features, modify categoricalFeaturesInfo.
  */
-object GradientBoostedTrees {
+object GradientBoostedTreesRunner {
 
   case class Params(
       input: String = null,
@@ -93,24 +93,24 @@ object GradientBoostedTrees {
 
   def run(params: Params) {
 
-    val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+    val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
     val sc = new SparkContext(conf)
 
-    println(s"GradientBoostedTrees with parameters:\n$params")
+    println(s"GradientBoostedTreesRunner with parameters:\n$params")
 
     // Load training and test data and cache it.
     val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
       params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
 
     val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
-    boostingStrategy.numClassesForClassification = numClasses
+    boostingStrategy.treeStrategy.numClassesForClassification = numClasses
     boostingStrategy.numIterations = params.numIterations
-    boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+    boostingStrategy.treeStrategy.maxDepth = params.maxDepth
 
     val randomSeed = Utils.random.nextInt()
     if (params.algo == "Classification") {
       val startTime = System.nanoTime()
-      val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+      val model = GradientBoostedTrees.train(training, boostingStrategy)
       val elapsedTime = (System.nanoTime() - startTime) / 1e9
       println(s"Training time: $elapsedTime seconds")
       if (model.totalNumNodes < 30) {
@@ -127,7 +127,7 @@
       println(s"Test accuracy = $testAccuracy")
     } else if (params.algo == "Regression") {
       val startTime = System.nanoTime()
-      val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+      val model = GradientBoostedTrees.train(training, boostingStrategy)
       val elapsedTime = (System.nanoTime() - startTime) / 1e9
       println(s"Training time: $elapsedTime seconds")
       if (model.totalNumNodes < 30) {

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 13 additions & 7 deletions
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @return DecisionTreeModel that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
     // Note: random seed will not be used since numTrees = 1.
     val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
-    val rfModel = rf.train(input)
-    rfModel.weakHypotheses(0)
+    val rfModel = rf.run(input)
+    rfModel.trees(0)
   }
 
+  /**
+   * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+   * methods with the same name in Java.
+   */
+  @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+  def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
 }
 
 object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
       maxDepth: Int,
       numClassesForClassification: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
  }
 
   /**
@@ -177,7 +183,7 @@
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
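After this change the instance API reads as follows from spark-shell. This is a minimal sketch: the data path and impurity choice are illustrative, and the static DecisionTree.train overloads shown above keep their names.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.mllib.util.MLUtils

    // Any RDD[LabeledPoint] works; this uses the sample data bundled with Spark.
    val input = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    val strategy = new Strategy(Algo.Classification, Gini, 5)
    val model = new DecisionTree(strategy).run(input)            // run replaces the instance method train
    val oldStyle = new DecisionTree(strategy).train(input)       // still compiles, now deprecated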
