@@ -144,6 +144,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
144
144
new DecisionTreeModel (topNode, strategy.algo)
145
145
}
146
146
147
+ // TODO: Unit test this
147
148
/**
148
149
* Extract the decision tree node information for the given tree level and node index
149
150
*/
@@ -161,6 +162,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
161
162
nodes(nodeIndex) = node
162
163
}
163
164
165
+ // TODO: Unit test this
164
166
/**
165
167
* Extract the decision tree node information for the children of the node
166
168
*/
@@ -458,6 +460,8 @@ object DecisionTree extends Serializable with Logging {
458
460
logDebug(" numClasses = " + numClasses)
459
461
val labelWeights = strategy.labelWeights
460
462
logDebug(" labelWeights = " + labelWeights)
463
+ val isMulticlassClassification = strategy.isMulticlassClassification
464
+ logDebug(" isMulticlassClassification = " + isMulticlassClassification)
461
465
462
466
463
467
// shift when more than one group is used at deep tree level
@@ -582,7 +586,7 @@ object DecisionTree extends Serializable with Logging {
582
586
} else {
583
587
// Perform sequential search to find bin for categorical features.
584
588
val binIndex = {
585
- if (strategy.isMultiClassification ) {
589
+ if (isMulticlassClassification ) {
586
590
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
587
591
} else {
588
592
sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
@@ -606,7 +610,9 @@ object DecisionTree extends Serializable with Logging {
606
610
def findBinsForLevel (labeledPoint : WeightedLabeledPoint ): Array [Double ] = {
607
611
// Calculate bin index and label per feature per node.
608
612
val arr = new Array [Double ](1 + (numFeatures * numNodes))
613
+ // First element of the array is the label of the instance.
609
614
arr(0 ) = labeledPoint.label
615
+ // Iterate over nodes.
610
616
var nodeIndex = 0
611
617
while (nodeIndex < numNodes) {
612
618
val parentFilters = findParentFilters(nodeIndex)
@@ -629,7 +635,10 @@ object DecisionTree extends Serializable with Logging {
629
635
arr
630
636
}
631
637
632
- /**
638
+ // Find feature bins for all nodes at a level.
639
+ val binMappedRDD = input.map(x => findBinsForLevel(x))
640
+
641
+ /**
633
642
* Performs a sequential aggregation over a partition for classification. For l nodes,
634
643
* k features, either the left count or the right count of one of the p bins is
635
644
* incremented based upon whether the feature is classified as 0 or 1.
@@ -663,7 +672,7 @@ object DecisionTree extends Serializable with Logging {
663
672
label.toInt match {
664
673
case n : Int =>
665
674
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
666
- if (! isFeatureContinuous && strategy.isMultiClassification ) {
675
+ if (! isFeatureContinuous && isMulticlassClassification ) {
667
676
// Find all matching bins and increment their values
668
677
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
669
678
val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
@@ -736,7 +745,6 @@ object DecisionTree extends Serializable with Logging {
736
745
agg
737
746
}
738
747
739
- // TODO: Double-check this
740
748
// Calculate bin aggregate length for classification or regression.
741
749
val binAggregateLength = strategy.algo match {
742
750
case Classification => numClasses * numBins * numFeatures * numNodes
@@ -760,9 +768,6 @@ object DecisionTree extends Serializable with Logging {
760
768
combinedAggregate
761
769
}
762
770
763
- // Find feature bins for all nodes at a level.
764
- val binMappedRDD = input.map(x => findBinsForLevel(x))
765
-
766
771
// Calculate bin aggregates.
767
772
val binAggregates = {
768
773
binMappedRDD.aggregate(Array .fill[Double ](binAggregateLength)(0 ))(binSeqOp,binCombOp)
@@ -922,7 +927,7 @@ object DecisionTree extends Serializable with Logging {
922
927
val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
923
928
val rightNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
924
929
925
- if (strategy.isMultiClassification ) {
930
+ if (isMulticlassClassification ) {
926
931
var featureIndex = 0
927
932
while (featureIndex < numFeatures){
928
933
var splitIndex = 0
@@ -1096,7 +1101,7 @@ object DecisionTree extends Serializable with Logging {
1096
1101
numBins - 1
1097
1102
} else { // Categorical feature
1098
1103
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1099
- if (strategy.isMultiClassification ) {
1104
+ if (isMulticlassClassification ) {
1100
1105
math.pow(2.0 , featureCategories - 1 ).toInt - 1
1101
1106
} else { // Binary classification
1102
1107
featureCategories
@@ -1177,6 +1182,9 @@ object DecisionTree extends Serializable with Logging {
1177
1182
val maxBins = strategy.maxBins
1178
1183
val numBins = if (maxBins <= count) maxBins else count.toInt
1179
1184
logDebug(" numBins = " + numBins)
1185
+ val isMulticlassClassification = strategy.isMulticlassClassification
1186
+ logDebug(" isMulticlassClassification = " + isMulticlassClassification)
1187
+
1180
1188
1181
1189
/*
1182
1190
* Ensure #bins is always greater than the categories. For multiclass classification,
@@ -1187,7 +1195,7 @@ object DecisionTree extends Serializable with Logging {
1187
1195
if (strategy.categoricalFeaturesInfo.size > 0 ) {
1188
1196
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
1189
1197
require(numBins > maxCategoriesForFeatures)
1190
- if (strategy.isMultiClassification ) {
1198
+ if (isMulticlassClassification ) {
1191
1199
require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 )
1192
1200
}
1193
1201
}
@@ -1230,7 +1238,7 @@ object DecisionTree extends Serializable with Logging {
1230
1238
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1231
1239
1232
1240
// Use different bin/split calculation strategy for multiclass classification
1233
- if (strategy.isMultiClassification ) {
1241
+ if (isMulticlassClassification ) {
1234
1242
// 2^(maxFeatureValue- 1) - 1 combinations
1235
1243
var index = 0
1236
1244
while (index < math.pow(2.0 , featureCategories - 1 ).toInt - 1 ) {
0 commit comments