@@ -202,6 +202,10 @@ object DecisionTree extends Serializable with Logging {
202
202
* Method to train a decision tree model.
203
203
* The method supports binary and multiclass classification and regression.
204
204
*
205
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
206
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
207
+ * is recommended to clearly separate classification and regression.
208
+ *
205
209
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
206
210
* For classification, labels should take values {0, 1, ..., numClasses-1}.
207
211
* For regression, labels are real numbers.
@@ -218,6 +222,10 @@ object DecisionTree extends Serializable with Logging {
218
222
* Method to train a decision tree model.
219
223
* The method supports binary and multiclass classification and regression.
220
224
*
225
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
226
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
227
+ * is recommended to clearly separate classification and regression.
228
+ *
221
229
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
222
230
* For classification, labels should take values {0, 1, ..., numClasses-1}.
223
231
* For regression, labels are real numbers.
@@ -240,6 +248,10 @@ object DecisionTree extends Serializable with Logging {
240
248
* Method to train a decision tree model.
241
249
* The method supports binary and multiclass classification and regression.
242
250
*
251
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
252
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
253
+ * is recommended to clearly separate classification and regression.
254
+ *
243
255
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
244
256
* For classification, labels should take values {0, 1, ..., numClasses-1}.
245
257
* For regression, labels are real numbers.
@@ -264,6 +276,10 @@ object DecisionTree extends Serializable with Logging {
264
276
* Method to train a decision tree model.
265
277
* The method supports binary and multiclass classification and regression.
266
278
*
279
+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
280
+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
281
+ * is recommended to clearly separate classification and regression.
282
+ *
267
283
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
268
284
* For classification, labels should take values {0, 1, ..., numClasses-1}.
269
285
* For regression, labels are real numbers.
@@ -293,92 +309,22 @@ object DecisionTree extends Serializable with Logging {
293
309
new DecisionTree (strategy).train(input)
294
310
}
295
311
296
- /**
297
- * Method to train a decision tree model.
298
- * The method supports binary and multiclass classification and regression.
299
- * This version takes basic types, for consistency with Python API.
300
- *
301
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
302
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
303
- * For regression, labels are real numbers.
304
- * @param algo "classification" or "regression"
305
- * @param numClassesForClassification number of classes for classification. Default value of 2.
306
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
307
- * E.g., an entry (n -> k) indicates that feature n is categorical
308
- * with k categories indexed from 0: {0, 1, ..., k-1}.
309
- * @param impurity criterion used for information gain calculation
310
- * @param maxDepth Maximum depth of the tree.
311
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
312
- * @param maxBins maximum number of bins used for splitting features
313
- * (default Python value = 100)
314
- * @return DecisionTreeModel that can be used for prediction
315
- */
316
- def train (
317
- input : RDD [LabeledPoint ],
318
- algo : String ,
319
- numClassesForClassification : Int ,
320
- categoricalFeaturesInfo : Map [Int , Int ],
321
- impurity : String ,
322
- maxDepth : Int ,
323
- maxBins : Int ): DecisionTreeModel = {
324
- val algoType = Algo .stringToAlgo(algo)
325
- val impurityType = Impurities .stringToImpurity(impurity)
326
- train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort ,
327
- categoricalFeaturesInfo)
328
- }
329
-
330
- /**
331
- * Method to train a decision tree model.
332
- * The method supports binary and multiclass classification and regression.
333
- * This version takes basic types, for consistency with Python API.
334
- * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
335
- *
336
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
337
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
338
- * For regression, labels are real numbers.
339
- * @param algo "classification" or "regression"
340
- * @param numClassesForClassification number of classes for classification. Default value of 2.
341
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
342
- * E.g., an entry (n -> k) indicates that feature n is categorical
343
- * with k categories indexed from 0: {0, 1, ..., k-1}.
344
- * @param impurity criterion used for information gain calculation
345
- * @param maxDepth Maximum depth of the tree.
346
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
347
- * @param maxBins maximum number of bins used for splitting features
348
- * (default Python value = 100)
349
- * @return DecisionTreeModel that can be used for prediction
350
- */
351
- def train (
352
- input : RDD [LabeledPoint ],
353
- algo : String ,
354
- numClassesForClassification : Int ,
355
- categoricalFeaturesInfo : java.util.Map [java.lang.Integer , java.lang.Integer ],
356
- impurity : String ,
357
- maxDepth : Int ,
358
- maxBins : Int ): DecisionTreeModel = {
359
- train(input, algo, numClassesForClassification,
360
- categoricalFeaturesInfo.asInstanceOf [java.util.Map [Int , Int ]].asScala.toMap,
361
- impurity, maxDepth, maxBins)
362
- }
363
-
364
312
/**
365
313
* Method to train a decision tree model for binary or multiclass classification.
366
- * This version takes basic types, for consistency with Python API.
367
314
*
368
315
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
369
316
* Labels should take values {0, 1, ..., numClasses-1}.
370
317
* @param numClassesForClassification number of classes for classification.
371
318
* @param categoricalFeaturesInfo Map storing arity of categorical features.
372
319
* E.g., an entry (n -> k) indicates that feature n is categorical
373
320
* with k categories indexed from 0: {0, 1, ..., k-1}.
374
- * (default Python value = {}, i.e., no categorical features)
375
- * @param impurity criterion used for information gain calculation
376
- * (default Python value = "gini")
321
+ * @param impurity Criterion used for information gain calculation.
322
+ * Supported values: "gini" (recommended) or "entropy".
377
323
* @param maxDepth Maximum depth of the tree.
378
324
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
379
- * (default Python value = 4)
325
+ * (suggested value: 4)
380
326
* @param maxBins maximum number of bins used for splitting features
381
- * (default Python value = 100)
327
+ * (suggested value: 100)
382
328
* @return DecisionTreeModel that can be used for prediction
383
329
*/
384
330
def trainClassifier (
@@ -388,30 +334,13 @@ object DecisionTree extends Serializable with Logging {
388
334
impurity : String ,
389
335
maxDepth : Int ,
390
336
maxBins : Int ): DecisionTreeModel = {
391
- train(input, " classification" , numClassesForClassification, categoricalFeaturesInfo, impurity,
392
- maxDepth, maxBins)
337
+ val impurityType = Impurities .fromString(impurity)
338
+ train(input, Classification , impurityType, maxDepth, numClassesForClassification, maxBins, Sort ,
339
+ categoricalFeaturesInfo)
393
340
}
394
341
395
342
/**
396
- * Method to train a decision tree model for binary or multiclass classification.
397
- * This version takes basic types, for consistency with Python API.
398
- * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
399
- *
400
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
401
- * Labels should take values {0, 1, ..., numClasses-1}.
402
- * @param numClassesForClassification number of classes for classification.
403
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
404
- * E.g., an entry (n -> k) indicates that feature n is categorical
405
- * with k categories indexed from 0: {0, 1, ..., k-1}.
406
- * (default Python value = {}, i.e., no categorical features)
407
- * @param impurity criterion used for information gain calculation
408
- * (default Python value = "gini")
409
- * @param maxDepth Maximum depth of the tree.
410
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
411
- * (default Python value = 4)
412
- * @param maxBins maximum number of bins used for splitting features
413
- * (default Python value = 100)
414
- * @return DecisionTreeModel that can be used for prediction
343
+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]].
415
344
*/
416
345
def trainClassifier (
417
346
input : RDD [LabeledPoint ],
@@ -427,21 +356,19 @@ object DecisionTree extends Serializable with Logging {
427
356
428
357
/**
429
358
* Method to train a decision tree model for regression.
430
- * This version takes basic types, for consistency with Python API.
431
359
*
432
360
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
433
361
* Labels are real numbers.
434
362
* @param categoricalFeaturesInfo Map storing arity of categorical features.
435
363
* E.g., an entry (n -> k) indicates that feature n is categorical
436
364
* with k categories indexed from 0: {0, 1, ..., k-1}.
437
- * (default Python value = {}, i.e., no categorical features)
438
- * @param impurity criterion used for information gain calculation
439
- * (default Python value = "variance")
365
+ * @param impurity Criterion used for information gain calculation.
366
+ * Supported values: "variance".
440
367
* @param maxDepth Maximum depth of the tree.
441
368
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
442
- * (default Python value = 4)
369
+ * (suggested value: 4)
443
370
* @param maxBins maximum number of bins used for splitting features
444
- * (default Python value = 100)
371
+ * (suggested value: 100)
445
372
* @return DecisionTreeModel that can be used for prediction
446
373
*/
447
374
def trainRegressor (
@@ -450,28 +377,12 @@ object DecisionTree extends Serializable with Logging {
450
377
impurity : String ,
451
378
maxDepth : Int ,
452
379
maxBins : Int ): DecisionTreeModel = {
453
- train(input, " regression" , 0 , categoricalFeaturesInfo, impurity, maxDepth, maxBins)
380
+ val impurityType = Impurities .fromString(impurity)
381
+ train(input, Regression , impurityType, maxDepth, 0 , maxBins, Sort , categoricalFeaturesInfo)
454
382
}
455
383
456
384
/**
457
- * Method to train a decision tree model for regression.
458
- * This version takes basic types, for consistency with Python API.
459
- * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
460
- *
461
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
462
- * Labels are real numbers.
463
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
464
- * E.g., an entry (n -> k) indicates that feature n is categorical
465
- * with k categories indexed from 0: {0, 1, ..., k-1}.
466
- * (default Python value = {}, i.e., no categorical features)
467
- * @param impurity criterion used for information gain calculation
468
- * (default Python value = "variance")
469
- * @param maxDepth Maximum depth of the tree.
470
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
471
- * (default Python value = 4)
472
- * @param maxBins maximum number of bins used for splitting features
473
- * (default Python value = 100)
474
- * @return DecisionTreeModel that can be used for prediction
385
+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]].
475
386
*/
476
387
def trainRegressor (
477
388
input : RDD [LabeledPoint ],
@@ -1516,16 +1427,15 @@ object DecisionTree extends Serializable with Logging {
1516
1427
* Categorical features:
1517
1428
* For each feature, there is 1 bin per split.
1518
1429
* Splits and bins are handled in 2 ways:
1519
- * (a) For multiclass classification with a low-arity feature
1430
+ * (a) "unordered features"
1431
+ * For multiclass classification with a low-arity feature
1520
1432
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
1521
1433
* the feature is split based on subsets of categories.
1522
- * There are math.pow(2, (maxFeatureValue - 1) - 1) splits.
1523
- * (b) For regression and binary classification,
1434
+ * There are math.pow(2, maxFeatureValue - 1) - 1 splits.
1435
+ * (b) "ordered features"
1436
+ * For regression and binary classification,
1524
1437
* and for multiclass classification with a high-arity feature,
1525
- *
1526
-
1527
- * Categorical case (a) features are called unordered features.
1528
- * Other cases are called ordered features.
1438
+ * there is one bin per category.
1529
1439
*
1530
1440
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]]
1531
1441
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy ]] instance containing
0 commit comments