
Commit ee1d236
DecisionTree API updates:
* Removed train() function in Python API (tree.py)
  * Removed corresponding functions in Scala/Java API (the ones taking basic types)

DecisionTree internal updates:
* Renamed Algo and Impurity factory methods to fromString()

DecisionTree doc updates:
* Added notes recommending use of trainClassifier, trainRegressor
* Stated the supported values for impurity
* Shortened doc for Java-friendly train* functions
1 parent 00f820e commit ee1d236

File tree

5 files changed: +57 -168 lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 2 additions & 2 deletions
@@ -521,8 +521,8 @@ class PythonMLLibAPI extends Serializable {
 
     val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
 
-    val algo = Algo.stringToAlgo(algoStr)
-    val impurity = Impurities.stringToImpurity(impurityStr)
+    val algo = Algo.fromString(algoStr)
+    val impurity = Impurities.fromString(impurityStr)
 
     val strategy = new Strategy(
       algo = algo,

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 37 additions & 127 deletions
@@ -202,6 +202,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -218,6 +222,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -240,6 +248,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -264,6 +276,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   * is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -293,92 +309,22 @@ object DecisionTree extends Serializable with Logging {
     new DecisionTree(strategy).train(input)
   }
 
-  /**
-   * Method to train a decision tree model.
-   * The method supports binary and multiclass classification and regression.
-   * This version takes basic types, for consistency with Python API.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param algo "classification" or "regression"
-   * @param numClassesForClassification number of classes for classification. Default value of 2.
-   * @param categoricalFeaturesInfo Map storing arity of categorical features.
-   *                                E.g., an entry (n -> k) indicates that feature n is categorical
-   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   * @param impurity criterion used for information gain calculation
-   * @param maxDepth Maximum depth of the tree.
-   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
-   * @return DecisionTreeModel that can be used for prediction
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      algo: String,
-      numClassesForClassification: Int,
-      categoricalFeaturesInfo: Map[Int, Int],
-      impurity: String,
-      maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = {
-    val algoType = Algo.stringToAlgo(algo)
-    val impurityType = Impurities.stringToImpurity(impurity)
-    train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
-      categoricalFeaturesInfo)
-  }
-
-  /**
-   * Method to train a decision tree model.
-   * The method supports binary and multiclass classification and regression.
-   * This version takes basic types, for consistency with Python API.
-   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param algo "classification" or "regression"
-   * @param numClassesForClassification number of classes for classification. Default value of 2.
-   * @param categoricalFeaturesInfo Map storing arity of categorical features.
-   *                                E.g., an entry (n -> k) indicates that feature n is categorical
-   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   * @param impurity criterion used for information gain calculation
-   * @param maxDepth Maximum depth of the tree.
-   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
-   * @return DecisionTreeModel that can be used for prediction
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      algo: String,
-      numClassesForClassification: Int,
-      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
-      impurity: String,
-      maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = {
-    train(input, algo, numClassesForClassification,
-      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
-      impurity, maxDepth, maxBins)
-  }
-
   /**
    * Method to train a decision tree model for binary or multiclass classification.
-   * This version takes basic types, for consistency with Python API.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              Labels should take values {0, 1, ..., numClasses-1}.
    * @param numClassesForClassification number of classes for classification.
    * @param categoricalFeaturesInfo Map storing arity of categorical features.
    *                                E.g., an entry (n -> k) indicates that feature n is categorical
    *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   *                                (default Python value = {}, i.e., no categorical features)
-   * @param impurity criterion used for information gain calculation
-   *                 (default Python value = "gini")
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "gini" (recommended) or "entropy".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (default Python value = 4)
+   *                 (suggested value: 4)
    * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
+   *                (suggested value: 100)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainClassifier(
@@ -388,30 +334,13 @@ object DecisionTree extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int): DecisionTreeModel = {
-    train(input, "classification", numClassesForClassification, categoricalFeaturesInfo, impurity,
-      maxDepth, maxBins)
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo)
   }
 
   /**
-   * Method to train a decision tree model for binary or multiclass classification.
-   * This version takes basic types, for consistency with Python API.
-   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              Labels should take values {0, 1, ..., numClasses-1}.
-   * @param numClassesForClassification number of classes for classification.
-   * @param categoricalFeaturesInfo Map storing arity of categorical features.
-   *                                E.g., an entry (n -> k) indicates that feature n is categorical
-   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   *                                (default Python value = {}, i.e., no categorical features)
-   * @param impurity criterion used for information gain calculation
-   *                 (default Python value = "gini")
-   * @param maxDepth Maximum depth of the tree.
-   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (default Python value = 4)
-   * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
-   * @return DecisionTreeModel that can be used for prediction
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
@@ -427,21 +356,19 @@ object DecisionTree extends Serializable with Logging {
 
   /**
    * Method to train a decision tree model for regression.
-   * This version takes basic types, for consistency with Python API.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              Labels are real numbers.
    * @param categoricalFeaturesInfo Map storing arity of categorical features.
    *                                E.g., an entry (n -> k) indicates that feature n is categorical
    *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   *                                (default Python value = {}, i.e., no categorical features)
-   * @param impurity criterion used for information gain calculation
-   *                 (default Python value = "variance")
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "variance".
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (default Python value = 4)
+   *                 (suggested value: 4)
    * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
+   *                (suggested value: 100)
    * @return DecisionTreeModel that can be used for prediction
    */
   def trainRegressor(
@@ -450,28 +377,12 @@ object DecisionTree extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int): DecisionTreeModel = {
-    train(input, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
   }
 
   /**
-   * Method to train a decision tree model for regression.
-   * This version takes basic types, for consistency with Python API.
-   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              Labels are real numbers.
-   * @param categoricalFeaturesInfo Map storing arity of categorical features.
-   *                                E.g., an entry (n -> k) indicates that feature n is categorical
-   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
-   *                                (default Python value = {}, i.e., no categorical features)
-   * @param impurity criterion used for information gain calculation
-   *                 (default Python value = "variance")
-   * @param maxDepth Maximum depth of the tree.
-   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   *                 (default Python value = 4)
-   * @param maxBins maximum number of bins used for splitting features
-   *                (default Python value = 100)
-   * @return DecisionTreeModel that can be used for prediction
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
@@ -1516,16 +1427,15 @@ object DecisionTree extends Serializable with Logging {
    * Categorical features:
    *   For each feature, there is 1 bin per split.
    *   Splits and bins are handled in 2 ways:
-   *   (a) For multiclass classification with a low-arity feature
+   *   (a) "unordered features"
+   *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are math.pow(2, (maxFeatureValue - 1) - 1) splits.
-   *   (b) For regression and binary classification,
+   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+   *   (b) "ordered features"
+   *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
-   *
-
-   * Categorical case (a) features are called unordered features.
-   * Other cases are called ordered features.
+   *       there is one bin per category.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
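The last hunk above corrects the split count for unordered categorical features from math.pow(2, (maxFeatureValue - 1) - 1) to math.pow(2, maxFeatureValue - 1) - 1: a k-category feature admits 2^(k-1) - 1 distinct two-sided partitions. The corrected count can be sanity-checked with a small pure-Python sketch (illustrative only, not MLlib code):

```python
from itertools import combinations

def num_unordered_splits(k):
    """Number of distinct subset splits for an unordered categorical feature
    with k categories: 2^(k-1) - 1, since each split is an unordered partition
    of the categories into two non-empty sides."""
    return 2 ** (k - 1) - 1

def enumerate_unordered_splits(categories):
    """Enumerate one representative left-side per split. Fixing the first
    category on the left avoids counting mirror-image partitions twice."""
    first, rest = categories[0], categories[1:]
    splits = []
    for r in range(len(rest) + 1):
        for combo in combinations(rest, r):
            left = (first,) + combo
            if len(left) < len(categories):  # both sides must be non-empty
                splits.append(left)
    return splits

# A 4-category feature has 2^(4-1) - 1 = 7 candidate splits.
print(num_unordered_splits(4))                        # 7
print(len(enumerate_unordered_splits([0, 1, 2, 3])))  # 7
```

Enumerating the partitions explicitly agrees with the closed form, which is why the old expression math.pow(2, (maxFeatureValue - 1) - 1) (i.e., 2^(k-2)) was wrong.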

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ object Algo extends Enumeration {
   type Algo = Value
   val Classification, Regression = Value
 
-  private[mllib] def stringToAlgo(name: String): Algo = name match {
+  private[mllib] def fromString(name: String): Algo = name match {
     case "classification" => Classification
     case "regression" => Regression
     case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
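The renamed fromString() is a plain string-to-enum factory that fails loudly on unrecognized names. A rough Python analogue of the pattern (hypothetical names; the real implementation is the Scala match shown in the diff):

```python
from enum import Enum

class Algo(Enum):
    CLASSIFICATION = "classification"
    REGRESSION = "regression"

def algo_from_string(name):
    """Map a user-supplied string to an Algo value, raising on typos
    rather than silently defaulting."""
    try:
        return Algo(name)
    except ValueError:
        raise ValueError("Did not recognize Algo name: %s" % name)

print(algo_from_string("classification"))  # Algo.CLASSIFICATION
```

Centralizing the mapping in one factory means the Python and Scala front ends share a single point of validation for the algo string.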

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala

Lines changed: 2 additions & 2 deletions
@@ -18,11 +18,11 @@
 package org.apache.spark.mllib.tree.impurity
 
 /**
- * Factory class for Impurity types.
+ * Factory for Impurity.
  */
 private[mllib] object Impurities {
 
-  def stringToImpurity(name: String): Impurity = name match {
+  def fromString(name: String): Impurity = name match {
     case "gini" => Gini
     case "entropy" => Entropy
     case "variance" => Variance
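For reference, the three impurity names accepted by Impurities.fromString correspond to standard formulas. A plain-Python sketch of those formulas (not MLlib's implementation, which works on aggregated sufficient statistics):

```python
import math

def gini(class_probs):
    """Gini impurity: 1 - sum(p_i^2) over class probabilities."""
    return 1.0 - sum(p * p for p in class_probs)

def entropy(class_probs):
    """Entropy: -sum(p_i * log2(p_i)), skipping zero-probability classes."""
    return -sum(p * math.log(p, 2) for p in class_probs if p > 0)

def variance(labels):
    """Variance impurity for regression: mean squared deviation from the mean."""
    mean = sum(labels) / len(labels)
    return sum((y - mean) ** 2 for y in labels) / len(labels)

print(gini([0.5, 0.5]))      # 0.5
print(variance([1.0, 3.0]))  # 1.0
```

This is why the classifier doc above lists "gini" and "entropy" while the regressor lists only "variance": the first two need class probabilities, the last needs real-valued labels.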

python/pyspark/mllib/tree.py

Lines changed: 15 additions & 36 deletions
@@ -128,7 +128,7 @@ class DecisionTree(object):
     """
 
     @staticmethod
-    def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+    def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                         impurity="gini", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for classification.
@@ -147,12 +147,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "classification", numClasses,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
+        sc = data.context
+        dataBytes = _get_unmangled_labeled_point_rdd(data)
+        categoricalFeaturesInfoJMap = \
+            MapConverter().convert(categoricalFeaturesInfo,
+                                   sc._gateway._gateway_client)
+        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+            dataBytes._jrdd, "classification",
+            numClasses, categoricalFeaturesInfoJMap,
+            impurity, maxDepth, maxBins)
+        dataBytes.unpersist()
+        return DecisionTreeModel(sc, model)
 
     @staticmethod
-    def trainRegressor(data, categoricalFeaturesInfo={},
+    def trainRegressor(data, categoricalFeaturesInfo,
                        impurity="variance", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for regression.
@@ -170,43 +178,14 @@ def trainRegressor(data, categoricalFeaturesInfo={},
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "regression", 0,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
-
-    @staticmethod
-    def train(data, algo, numClasses, categoricalFeaturesInfo,
-              impurity, maxDepth, maxBins=100):
-        """
-        Train a DecisionTreeModel for classification or regression.
-
-        :param data: Training data: RDD of LabeledPoint.
-                     For classification, labels are integers
-                     {0,1,...,numClasses}.
-                     For regression, labels are real numbers.
-        :param algo: "classification" or "regression"
-        :param numClasses: Number of classes for classification.
-        :param categoricalFeaturesInfo: Map from categorical feature index
-                                        to number of categories.
-                                        Any feature not in this map
-                                        is treated as continuous.
-        :param impurity: For classification: "entropy" or "gini".
-                         For regression: "variance".
-        :param maxDepth: Max depth of tree.
-                         E.g., depth 0 means 1 leaf node.
-                         Depth 1 means 1 internal node + 2 leaf nodes.
-        :param maxBins: Number of bins used for finding splits at each node.
-        :return: DecisionTreeModel
-        """
         sc = data.context
         dataBytes = _get_unmangled_labeled_point_rdd(data)
         categoricalFeaturesInfoJMap = \
             MapConverter().convert(categoricalFeaturesInfo,
                                    sc._gateway._gateway_client)
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
-            dataBytes._jrdd, algo,
-            numClasses, categoricalFeaturesInfoJMap,
+            dataBytes._jrdd, "regression",
+            0, categoricalFeaturesInfoJMap,
             impurity, maxDepth, maxBins)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)
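After this change the Python API exposes only trainClassifier and trainRegressor, each hard-coding the algo string that the removed train() used to take as a parameter; the diff also makes categoricalFeaturesInfo a required argument instead of a mutable {} default. The shape of the refactor can be sketched without Spark, using a hypothetical _train_via_jvm stand-in for the Py4J plumbing:

```python
def _train_via_jvm(data, algo, num_classes, categorical_features_info,
                   impurity, max_depth, max_bins):
    # Stand-in for serializing the RDD and calling the Scala API
    # (PythonMLLibAPI.trainDecisionTreeModel) through the JVM gateway.
    return {"algo": algo, "numClasses": num_classes, "impurity": impurity,
            "maxDepth": max_depth, "maxBins": max_bins}

def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                    impurity="gini", maxDepth=4, maxBins=100):
    # Classification is hard-coded; no user-facing algo string anymore.
    return _train_via_jvm(data, "classification", numClasses,
                          categoricalFeaturesInfo, impurity, maxDepth, maxBins)

def trainRegressor(data, categoricalFeaturesInfo,
                   impurity="variance", maxDepth=4, maxBins=100):
    # Regression passes 0 for numClasses, matching the Scala trainRegressor.
    return _train_via_jvm(data, "regression", 0,
                          categoricalFeaturesInfo, impurity, maxDepth, maxBins)

print(trainClassifier(None, 3, {})["algo"])  # classification
```

Removing the string-dispatched train() means a typo like algo="clasification" can no longer select the wrong mode silently; the caller picks a method name instead.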

0 commit comments