Skip to content

Commit 8c81831

Browse files
committed
Update the default values of the decision tree parameters:
1. maxMemoryInMB: 128 -> 256
2. maxBins: 100 -> 32
3. maxDepth: 4 -> 5 (in some example code)
1 parent 16a73c2 commit 8c81831

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

docs/mllib-decision-tree.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ The ordered splits create "bins" and the maximum number of such
8080
bins can be specified using the `maxBins` parameter.
8181

8282
Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
83-
since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
83+
since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of
8484
bins if the condition is not satisfied.
8585

8686
**Categorical features**
@@ -117,7 +117,7 @@ all nodes at each level of the tree. This could lead to high memory requirements
117117
of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB`
118118
training parameter specifies the maximum amount of memory at the workers (twice as much at the
119119
master) to be allocated to the histogram computation. The default value is conservatively chosen to
120-
be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
120+
be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
121121
for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
122122
subsequent level are split into smaller tasks.
123123

@@ -167,7 +167,7 @@ val numClasses = 2
167167
val categoricalFeaturesInfo = Map[Int, Int]()
168168
val impurity = "gini"
169169
val maxDepth = 5
170-
val maxBins = 100
170+
val maxBins = 32
171171

172172
val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
173173
maxDepth, maxBins)
@@ -213,7 +213,7 @@ Integer numClasses = 2;
213213
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
214214
String impurity = "gini";
215215
Integer maxDepth = 5;
216-
Integer maxBins = 100;
216+
Integer maxBins = 32;
217217

218218
// Train a DecisionTree model for classification.
219219
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
@@ -250,7 +250,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
250250
# Train a DecisionTree model.
251251
# Empty categoricalFeaturesInfo indicates all features are continuous.
252252
model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
253-
impurity='gini', maxDepth=5, maxBins=100)
253+
impurity='gini', maxDepth=5, maxBins=32)
254254

255255
# Evaluate model on training instances and compute training error
256256
predictions = model.predict(data.map(lambda x: x.features))
@@ -293,7 +293,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache
293293
val categoricalFeaturesInfo = Map[Int, Int]()
294294
val impurity = "variance"
295295
val maxDepth = 5
296-
val maxBins = 100
296+
val maxBins = 32
297297

298298
val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
299299
maxDepth, maxBins)
@@ -338,7 +338,7 @@ JavaSparkContext sc = new JavaSparkContext(sparkConf);
338338
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
339339
String impurity = "variance";
340340
Integer maxDepth = 5;
341-
Integer maxBins = 100;
341+
Integer maxBins = 32;
342342

343343
// Train a DecisionTree model.
344344
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
@@ -380,7 +380,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
380380
# Train a DecisionTree model.
381381
# Empty categoricalFeaturesInfo indicates all features are continuous.
382382
model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
383-
impurity='variance', maxDepth=5, maxBins=100)
383+
impurity='variance', maxDepth=5, maxBins=32)
384384

385385
# Evaluate model on training instances and compute training error
386386
predictions = model.predict(data.map(lambda x: x.features))

examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static void main(String[] args) {
6363
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
6464
String impurity = "gini";
6565
Integer maxDepth = 5;
66-
Integer maxBins = 100;
66+
Integer maxBins = 32;
6767

6868
// Train a DecisionTree model for classification.
6969
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ object DecisionTreeRunner {
5252
input: String = null,
5353
dataFormat: String = "libsvm",
5454
algo: Algo = Classification,
55-
maxDepth: Int = 4,
55+
maxDepth: Int = 5,
5656
impurity: ImpurityType = Gini,
57-
maxBins: Int = 100,
57+
maxBins: Int = 32,
5858
fracTest: Double = 0.2)
5959

6060
def main(args: Array[String]) {

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,18 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
5050
* 1, 2, ... , k-1. It's important to note that features are
5151
* zero-indexed.
5252
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
53-
* 128 MB.
53+
* 256 MB.
5454
*/
5555
@Experimental
5656
class Strategy (
5757
val algo: Algo,
5858
val impurity: Impurity,
5959
val maxDepth: Int,
6060
val numClassesForClassification: Int = 2,
61-
val maxBins: Int = 100,
61+
val maxBins: Int = 32,
6262
val quantileCalculationStrategy: QuantileStrategy = Sort,
6363
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
64-
val maxMemoryInMB: Int = 128) extends Serializable {
64+
val maxMemoryInMB: Int = 256) extends Serializable {
6565

6666
if (algo == Classification) {
6767
require(numClassesForClassification >= 2)

python/pyspark/mllib/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class DecisionTree(object):
138138

139139
@staticmethod
140140
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
141-
impurity="gini", maxDepth=4, maxBins=100):
141+
impurity="gini", maxDepth=5, maxBins=32):
142142
"""
143143
Train a DecisionTreeModel for classification.
144144
@@ -170,7 +170,7 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
170170

171171
@staticmethod
172172
def trainRegressor(data, categoricalFeaturesInfo,
173-
impurity="variance", maxDepth=4, maxBins=100):
173+
impurity="variance", maxDepth=5, maxBins=32):
174174
"""
175175
Train a DecisionTreeModel for regression.
176176

0 commit comments

Comments
 (0)