@@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging {
516516 * Find bin for one feature.
517517 */
518518 def findBin (featureIndex : Int , labeledPoint : WeightedLabeledPoint ,
519- isFeatureContinuous : Boolean ): Int = {
519+ isFeatureContinuous : Boolean , isSpaceSufficientForAllCategoricalSplits : Boolean ): Int = {
520520 val binForFeatures = bins(featureIndex)
521521 val feature = labeledPoint.features(featureIndex)
522522
@@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging {
550550 * splits. The actual left/right child allocation per split is performed in the
551551 * sequential phase of the bin aggregate operation.
552552 */
553- def sequentialBinSearchForCategoricalFeatureInMulticlassClassification (): Int = {
553+ def sequentialBinSearchForUnorderedCategoricalFeatureInClassification (): Int = {
554554 labeledPoint.features(featureIndex).toInt
555555 }
556556
557557 /**
558558 * Sequential search helper method to find bin for categorical feature.
559559 */
560- def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
560+ def sequentialBinSearchForOrderedCategoricalFeatureInClassification (): Int = {
561561 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
562562 val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
563563 var binIndex = 0
@@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging {
583583 } else {
584584 // Perform sequential search to find bin for categorical features.
585585 val binIndex = {
586- if (isMulticlassClassification) {
587- sequentialBinSearchForCategoricalFeatureInMulticlassClassification ()
586+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits ) {
587+ sequentialBinSearchForUnorderedCategoricalFeatureInClassification ()
588588 } else {
589- sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
589+ sequentialBinSearchForOrderedCategoricalFeatureInClassification ()
590590 }
591591 }
592592 if (binIndex == - 1 ){
@@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging {
622622 } else {
623623 var featureIndex = 0
624624 while (featureIndex < numFeatures) {
625- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
626- arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
625+ val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
626+ val isFeatureContinuous = featureInfo.isEmpty
627+ if (isFeatureContinuous) {
628+ arr(shift + featureIndex)
629+ = findBin(featureIndex, labeledPoint, isFeatureContinuous, false )
630+ } else {
631+ val featureCategories = featureInfo.get
632+ val isSpaceSufficientForAllCategoricalSplits
633+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
634+ arr(shift + featureIndex)
635+ = findBin(featureIndex, labeledPoint, isFeatureContinuous,
636+ isSpaceSufficientForAllCategoricalSplits)
637+ }
627638 featureIndex += 1
628639 }
629640 }
@@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging {
731742 // Iterate over all features.
732743 var featureIndex = 0
733744 while (featureIndex < numFeatures) {
734- val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
735- if (isContinuousFeature ) {
745+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
746+ if (isFeatureContinuous ) {
736747 updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
737748 } else {
738- updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
749+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
750+ val isSpaceSufficientForAllCategoricalSplits
751+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
752+ if (isSpaceSufficientForAllCategoricalSplits) {
753+ updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
754+ } else {
755+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
739756 }
757+ }
740758 featureIndex += 1
741759 }
742760 }
@@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging {
10931111 if (isFeatureContinuous) {
10941112 findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
10951113 } else {
1096- findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1114+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1115+ val isSpaceSufficientForAllCategoricalSplits
1116+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1117+ if (isSpaceSufficientForAllCategoricalSplits) {
1118+ findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1119+ } else {
1120+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1121+ }
10971122 }
10981123 } else {
10991124 findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging {
11681193 numBins - 1
11691194 } else { // Categorical feature
11701195 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1171- if (isMulticlassClassification) {
1196+ val isSpaceSufficientForAllCategoricalSplits
1197+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1198+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
11721199 math.pow(2.0 , featureCategories - 1 ).toInt - 1
11731200 } else { // Not multiclass, or too few bins to hold all unordered categorical splits
11741201 featureCategories
@@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging {
12891316 val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
12901317 require(numBins > maxCategoriesForFeatures, " numBins should be greater than max categories " +
12911318 " in categorical features" )
1292- if (isMulticlassClassification) {
1293- require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 ,
1294- " numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
1295- " with categorical variables" )
1296- }
12971319 }
12981320
12991321
@@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging {
13321354 }
13331355 } else { // Categorical feature
13341356 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1357+ val isSpaceSufficientForAllCategoricalSplits
1358+ = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
13351359
13361360 // Use different bin/split calculation strategy for categorical features in multiclass
1337- // classification
1338- if (isMulticlassClassification) {
1361+ // classification that satisfy the space constraint
1362+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits ) {
13391363 // 2^(maxFeatureValue- 1) - 1 combinations
13401364 var index = 0
13411365 while (index < math.pow(2.0 , featureCategories - 1 ).toInt - 1 ) {
@@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging {
13601384 }
13611385 index += 1
13621386 }
1363- } else { // regression or binary classification
1364-
1365- // For categorical variables, each bin is a category. The bins are sorted and they
1366- // are ordered by calculating the centroid of their corresponding labels.
1367- val centroidForCategories =
1368- sampledInput.map(lp => (lp.features(featureIndex),lp.label))
1369- .groupBy(_._1)
1370- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1387+ } else {
1388+
1389+ val centroidForCategories = {
1390+ if (isMulticlassClassification) {
1391+ // For categorical variables in multiclass classification,
1392+ // each bin is a category. The bins are sorted and they
1393+ // are ordered by calculating the impurity of their corresponding labels.
1394+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1395+ .groupBy(_._1)
1396+ .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
1397+ .map(x => (x._1, x._2.values.toArray))
1398+ .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
1399+ } else { // regression or binary classification
1400+ // For categorical variables in regression and binary classification,
1401+ // each bin is a category. The bins are sorted and they
1402+ // are ordered by calculating the centroid of their corresponding labels.
1403+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1404+ .groupBy(_._1)
1405+ .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1406+ }
1407+ }
1408+
1409+ logDebug(" centriod for categories = " + centroidForCategories.mkString(" ," ))
13711410
13721411 // Check for missing categorical variables and put them last in the sorted list.
13731412 val fullCentroidForCategories = scala.collection.mutable.Map [Double ,Double ]()
0 commit comments