@@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging {
516
516
* Find bin for one feature.
517
517
*/
518
518
def findBin (featureIndex : Int , labeledPoint : WeightedLabeledPoint ,
519
- isFeatureContinuous : Boolean ): Int = {
519
+ isFeatureContinuous : Boolean , isSpaceSufficientForAllCategoricalSplits : Boolean ): Int = {
520
520
val binForFeatures = bins(featureIndex)
521
521
val feature = labeledPoint.features(featureIndex)
522
522
@@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging {
550
550
* splits. The actual left/right child allocation per split is performed in the
551
551
* sequential phase of the bin aggregate operation.
552
552
*/
553
- def sequentialBinSearchForCategoricalFeatureInMulticlassClassification (): Int = {
553
+ def sequentialBinSearchForUnorderedCategoricalFeatureInClassification (): Int = {
554
554
labeledPoint.features(featureIndex).toInt
555
555
}
556
556
557
557
/**
558
558
* Sequential search helper method to find bin for categorical feature.
559
559
*/
560
- def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
560
+ def sequentialBinSearchForOrderedCategoricalFeatureInClassification (): Int = {
561
561
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
562
562
val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
563
563
var binIndex = 0
@@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging {
583
583
} else {
584
584
// Perform sequential search to find bin for categorical features.
585
585
val binIndex = {
586
- if (isMulticlassClassification) {
587
- sequentialBinSearchForCategoricalFeatureInMulticlassClassification ()
586
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits ) {
587
+ sequentialBinSearchForUnorderedCategoricalFeatureInClassification ()
588
588
} else {
589
- sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
589
+ sequentialBinSearchForOrderedCategoricalFeatureInClassification ()
590
590
}
591
591
}
592
592
if (binIndex == - 1 ){
@@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging {
622
622
} else {
623
623
var featureIndex = 0
624
624
while (featureIndex < numFeatures) {
625
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
626
- arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
625
+ val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
626
+ val isFeatureContinuous = featureInfo.isEmpty
627
+ if (isFeatureContinuous) {
628
+ arr(shift + featureIndex)
629
+ = findBin(featureIndex, labeledPoint, isFeatureContinuous, false )
630
+ } else {
631
+ val featureCategories = featureInfo.get
632
+ val isSpaceSufficientForAllCategoricalSplits
633
+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
634
+ arr(shift + featureIndex)
635
+ = findBin(featureIndex, labeledPoint, isFeatureContinuous,
636
+ isSpaceSufficientForAllCategoricalSplits)
637
+ }
627
638
featureIndex += 1
628
639
}
629
640
}
@@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging {
731
742
// Iterate over all features.
732
743
var featureIndex = 0
733
744
while (featureIndex < numFeatures) {
734
- val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
735
- if (isContinuousFeature ) {
745
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
746
+ if (isFeatureContinuous ) {
736
747
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
737
748
} else {
738
- updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
749
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
750
+ val isSpaceSufficientForAllCategoricalSplits
751
+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
752
+ if (isSpaceSufficientForAllCategoricalSplits) {
753
+ updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
754
+ } else {
755
+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
739
756
}
757
+ }
740
758
featureIndex += 1
741
759
}
742
760
}
@@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging {
1093
1111
if (isFeatureContinuous) {
1094
1112
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1095
1113
} else {
1096
- findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1114
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1115
+ val isSpaceSufficientForAllCategoricalSplits
1116
+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1117
+ if (isSpaceSufficientForAllCategoricalSplits) {
1118
+ findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1119
+ } else {
1120
+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1121
+ }
1097
1122
}
1098
1123
} else {
1099
1124
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging {
1168
1193
numBins - 1
1169
1194
} else { // Categorical feature
1170
1195
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1171
- if (isMulticlassClassification) {
1196
+ val isSpaceSufficientForAllCategoricalSplits
1197
+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1198
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
1172
1199
math.pow(2.0 , featureCategories - 1 ).toInt - 1
1173
1200
} else { // Binary classification
1174
1201
featureCategories
@@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging {
1289
1316
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
1290
1317
require(numBins > maxCategoriesForFeatures, " numBins should be greater than max categories " +
1291
1318
" in categorical features" )
1292
- if (isMulticlassClassification) {
1293
- require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 ,
1294
- " numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
1295
- " with categorical variables" )
1296
- }
1297
1319
}
1298
1320
1299
1321
@@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging {
1332
1354
}
1333
1355
} else { // Categorical feature
1334
1356
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1357
+ val isSpaceSufficientForAllCategoricalSplits
1358
+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1335
1359
1336
1360
// Use different bin/split calculation strategy for categorical features in multiclass
1337
- // classification
1338
- if (isMulticlassClassification) {
1361
+ // classification that satisfy the space constraint
1362
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits ) {
1339
1363
// 2^(maxFeatureValue- 1) - 1 combinations
1340
1364
var index = 0
1341
1365
while (index < math.pow(2.0 , featureCategories - 1 ).toInt - 1 ) {
@@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging {
1360
1384
}
1361
1385
index += 1
1362
1386
}
1363
- } else { // regression or binary classification
1364
-
1365
- // For categorical variables, each bin is a category. The bins are sorted and they
1366
- // are ordered by calculating the centroid of their corresponding labels.
1367
- val centroidForCategories =
1368
- sampledInput.map(lp => (lp.features(featureIndex),lp.label))
1369
- .groupBy(_._1)
1370
- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1387
+ } else {
1388
+
1389
+ val centroidForCategories = {
1390
+ if (isMulticlassClassification) {
1391
+ // For categorical variables in multiclass classification,
1392
+ // each bin is a category. The bins are sorted and they
1393
+ // are ordered by calculating the impurity of their corresponding labels.
1394
+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1395
+ .groupBy(_._1)
1396
+ .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
1397
+ .map(x => (x._1, x._2.values.toArray))
1398
+ .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
1399
+ } else { // regression or binary classification
1400
+ // For categorical variables in regression and binary classification,
1401
+ // each bin is a category. The bins are sorted and they
1402
+ // are ordered by calculating the centroid of their corresponding labels.
1403
+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1404
+ .groupBy(_._1)
1405
+ .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1406
+ }
1407
+ }
1408
+
1409
+ logDebug(" centriod for categories = " + centroidForCategories.mkString(" ," ))
1371
1410
1372
1411
// Check for missing categorical variables and putting them last in the sorted list.
1373
1412
val fullCentroidForCategories = scala.collection.mutable.Map [Double ,Double ]()
0 commit comments