@@ -252,6 +252,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
252252 // noting the parents filters for the child nodes
253253 val childFilter = new Filter (nodeSplitStats._1, if (i == 0 ) - 1 else 1 )
254254 filters(nodeIndex) = childFilter :: filters((nodeIndex - 1 ) / 2 )
255+ // println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
255256 for (filter <- filters(nodeIndex)) {
256257 logDebug(" Filter = " + filter)
257258 }
@@ -477,7 +478,7 @@ object DecisionTree extends Serializable with Logging {
477478 * @param splits possible splits for all features
478479 * @param bins possible bins for all features
479480 * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
480- * @return array of splits with best splits for all nodes at a given level.
481+ * @return array (over nodes) of splits with best split for each node at a given level.
481482 */
482483 protected [tree] def findBestSplits (
483484 input : RDD [TreePoint ],
@@ -490,6 +491,7 @@ object DecisionTree extends Serializable with Logging {
490491 maxLevelForSingleGroup : Int ,
491492 timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats )] = {
492493 // split into groups to avoid memory overflow during aggregation
494+ // println(s"findBestSplits: level = $level")
493495 if (level > maxLevelForSingleGroup) {
494496 // When information for all nodes at a given level cannot be stored in memory,
495497 // the nodes are divided into multiple groups at each level with the number of groups
@@ -617,22 +619,32 @@ object DecisionTree extends Serializable with Logging {
617619 val featureIndex = filter.split.feature
618620 val comparison = filter.comparison
619621 val isFeatureContinuous = filter.split.featureType == Continuous
620- val binId = treePoint.features(featureIndex)
621- val bin = bins(featureIndex)(binId)
622622 if (isFeatureContinuous) {
623+ val binId = treePoint.features(featureIndex)
624+ val bin = bins(featureIndex)(binId)
623625 val featureValue = bin.highSplit.threshold
624626 val threshold = filter.split.threshold
625627 comparison match {
626628 case - 1 => if (featureValue > threshold) return false
627629 case 1 => if (featureValue <= threshold) return false
628630 }
629631 } else {
630- val containsFeature = filter.split.categories.contains(bin.category)
632+ val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
633+ val isSpaceSufficientForAllCategoricalSplits =
634+ numBins > math.pow(2 , numFeatureCategories.toInt - 1 ) - 1
635+ val isUnorderedFeature =
636+ isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
637+ val featureValue = if (isUnorderedFeature) {
638+ treePoint.features(featureIndex)
639+ } else {
640+ val binId = treePoint.features(featureIndex)
641+ bins(featureIndex)(binId).category
642+ }
643+ val containsFeature = filter.split.categories.contains(featureValue)
631644 comparison match {
632645 case - 1 => if (! containsFeature) return false
633646 case 1 => if (containsFeature) return false
634647 }
635-
636648 }
637649 }
638650
@@ -669,6 +681,7 @@ object DecisionTree extends Serializable with Logging {
669681 val parentFilters = findParentFilters(nodeIndex)
670682 // Find out whether the sample qualifies for the particular node.
671683 val sampleValid = isSampleValid(parentFilters, treePoint)
684+ // println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
672685 val shift = 1 + numFeatures * nodeIndex
673686 if (! sampleValid) {
674687 // Mark one bin as -1 is sufficient.
@@ -739,6 +752,7 @@ object DecisionTree extends Serializable with Logging {
739752 label : Double ,
740753 agg : Array [Double ],
741754 rightChildShift : Int ): Unit = {
755+ // println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
742756 // Find the bin index for this feature.
743757 val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
744758 val featureValue = arr(arrIndex).toInt
@@ -792,6 +806,8 @@ object DecisionTree extends Serializable with Logging {
792806 }
793807 }
794808
809+ val rightChildShift = numClasses * numBins * numFeatures * numNodes
810+
795811 /**
796812 * Helper for binSeqOp.
797813 *
@@ -814,8 +830,11 @@ object DecisionTree extends Serializable with Logging {
814830 // Check whether the instance was valid for this nodeIndex.
815831 val validSignalIndex = 1 + numFeatures * nodeIndex
816832 val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
833+ if (level == 1 ) {
834+ val nodeFilterIndex = math.pow(2 , level).toInt - 1 + nodeIndex + groupShift
835+ // println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
836+ }
817837 if (isSampleValidForNode) {
818- val rightChildShift = numClasses * numBins * numFeatures * numNodes
819838 // actual class label
820839 val label = arr(0 )
821840 // Iterate over all features.
@@ -874,7 +893,7 @@ object DecisionTree extends Serializable with Logging {
874893 val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
875894 agg(aggIndex) = agg(aggIndex) + 1
876895 agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + label
877- agg(aggIndex + 2 ) = agg(aggIndex + 2 ) + label* label
896+ agg(aggIndex + 2 ) = agg(aggIndex + 2 ) + label * label
878897 featureIndex += 1
879898 }
880899 }
@@ -944,6 +963,29 @@ object DecisionTree extends Serializable with Logging {
944963 logDebug(" binAggregates.length = " + binAggregates.length)
945964
946965 timer.binAggregatesTime += timer.elapsed()
966+ // 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
967+ // (left/right, node, feature, bin, label)
968+ /*
969+ println(s"binAggregates:")
970+ for (i <- Range(0,2)) {
971+ for (n <- Range(0,numNodes)) {
972+ for (f <- Range(0,numFeatures)) {
973+ for (b <- Range(0,4)) {
974+ for (c <- Range(0,numClasses)) {
975+ val idx = i * numClasses * numBins * numFeatures * numNodes +
976+ n * numClasses * numBins * numFeatures +
977+ f * numBins * numFeatures +
978+ b * numFeatures +
979+ c
980+ if (binAggregates(idx) != 0) {
981+ println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
982+ }
983+ }
984+ }
985+ }
986+ }
987+ }
988+ */
947989
948990 /**
949991 * Calculates the information gain for all splits based upon left/right split aggregates.
@@ -985,6 +1027,7 @@ object DecisionTree extends Serializable with Logging {
9851027 val totalCount = leftTotalCount + rightTotalCount
9861028 if (totalCount == 0 ) {
9871029 // Return arbitrary prediction.
1030+ // println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
9881031 return new InformationGainStats (0 , topImpurity, topImpurity, topImpurity, 0 )
9891032 }
9901033
@@ -997,13 +1040,23 @@ object DecisionTree extends Serializable with Logging {
/**
 * Index of the largest element of `array`. The earliest index wins on ties,
 * because only a strictly greater value replaces the running maximum.
 *
 * @param array values to scan
 * @return index of the first maximal element
 * @throws RuntimeException if no element compares greater than Double.MinValue
 *         (for example when the array is empty), rather than silently
 *         reporting index 0
 */
def indexOfLargestArrayElement(array: Array[Double]): Int = {
  // Accumulator: (index of best so far, best value so far, index of current element).
  // The seed is an explicit tuple so the call does not rely on argument auto-tupling.
  val (maxIndex, _, _) = array.foldLeft((-1, Double.MinValue, 0)) {
    case ((bestIndex, bestValue, currentIndex), currentValue) =>
      if (currentValue > bestValue) {
        (currentIndex, currentValue, currentIndex + 1)
      } else {
        (bestIndex, bestValue, currentIndex + 1)
      }
  }
  // A negative index means the fold never accepted any element.
  if (maxIndex < 0) {
    throw new RuntimeException(" DecisionTree internal error:" +
      " calculateGainForSplit failed in indexOfLargestArrayElement")
  }
  maxIndex
}
10051055
10061056 val predict = indexOfLargestArrayElement(leftRightCounts)
1057+ if (predict == 0 && featureIndex == 0 && splitIndex == 0 ) {
1058+ // println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
1059+ }
10071060 val prob = leftRightCounts(predict) / totalCount
10081061
10091062 val leftImpurity = if (leftTotalCount == 0 ) {
@@ -1023,6 +1076,7 @@ object DecisionTree extends Serializable with Logging {
10231076 val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
10241077
10251078 new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
1079+
10261080 case Regression =>
10271081 val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0 )
10281082 val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1 )
@@ -1140,6 +1194,7 @@ object DecisionTree extends Serializable with Logging {
11401194
11411195 val rightChildShift = numClasses * numBins * numFeatures
11421196 var splitIndex = 0
1197+ var TMPDEBUG = 0.0
11431198 while (splitIndex < numBins - 1 ) {
11441199 var classIndex = 0
11451200 while (classIndex < numClasses) {
@@ -1149,10 +1204,12 @@ object DecisionTree extends Serializable with Logging {
11491204 val rightBinValue = binData(rightChildShift + shift + classIndex)
11501205 leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
11511206 rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
1207+ TMPDEBUG += leftBinValue + rightBinValue
11521208 classIndex += 1
11531209 }
11541210 splitIndex += 1
11551211 }
1212+ // println(s"found Agg: $TMPDEBUG")
11561213 }
11571214
11581215 def findAggForRegression (
@@ -1247,14 +1304,36 @@ object DecisionTree extends Serializable with Logging {
12471304 val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
12481305
12491306 for (featureIndex <- 0 until numFeatures) {
1250- for (splitIndex <- 0 until numBins - 1 ) {
1307+ val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1308+ for (splitIndex <- 0 until numSplitsForFeature) {
12511309 gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
12521310 splitIndex, rightNodeAgg, nodeImpurity)
12531311 }
12541312 }
12551313 gains
12561314 }
12571315
/**
 * Number of candidate splits to evaluate for the given feature.
 *
 *  - A feature with no entry in `strategy.categoricalFeaturesInfo` is treated as
 *    continuous and has `numBins - 1` splits.
 *  - A categorical feature with k categories has `2^(k-1) - 1` splits when the
 *    problem is multiclass classification and `numBins` exceeds that count.
 *  - Any other categorical feature is handled as ordered and has k splits.
 *
 * @param featureIndex index of the feature
 * @return number of splits for that feature
 */
def getNumSplitsForFeature(featureIndex: Int): Int = {
  strategy.categoricalFeaturesInfo.get(featureIndex) match {
    case None =>
      // Continuous feature
      numBins - 1
    case Some(featureCategories) =>
      // Categorical feature
      val enoughBinsForAllSubsets =
        numBins > math.pow(2, featureCategories.toInt - 1) - 1
      val isUnordered = isMulticlassClassification && enoughBinsForAllSubsets
      if (isUnordered) {
        math.pow(2.0, featureCategories - 1).toInt - 1
      } else {
        // Ordered features
        featureCategories
      }
  }
}
1336+
12581337 /**
12591338 * Find the best split for a node.
12601339 * @param binData Bin data slice for this node, given by getBinDataForNode.
@@ -1273,7 +1352,7 @@ object DecisionTree extends Serializable with Logging {
12731352 // Calculate gains for all splits.
12741353 val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
12751354
1276- val (bestFeatureIndex,bestSplitIndex, gainStats) = {
1355+ val (bestFeatureIndex, bestSplitIndex, gainStats) = {
12771356 // Initialize with infeasible values.
12781357 var bestFeatureIndex = Int .MinValue
12791358 var bestSplitIndex = Int .MinValue
@@ -1283,27 +1362,14 @@ object DecisionTree extends Serializable with Logging {
12831362 while (featureIndex < numFeatures) {
12841363 // Iterate over all splits.
12851364 var splitIndex = 0
1286- val maxSplitIndex : Double = {
1287- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
1288- if (isFeatureContinuous) {
1289- numBins - 1
1290- } else { // Categorical feature
1291- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1292- val isSpaceSufficientForAllCategoricalSplits
1293- = numBins > math.pow(2 , featureCategories.toInt - 1 ) - 1
1294- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
1295- math.pow(2.0 , featureCategories - 1 ).toInt - 1
1296- } else { // Binary classification
1297- featureCategories
1298- }
1299- }
1300- }
1301- while (splitIndex < maxSplitIndex) {
1365+ val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1366+ while (splitIndex < numSplitsForFeature) {
13021367 val gainStats = gains(featureIndex)(splitIndex)
13031368 if (gainStats.gain > bestGainStats.gain) {
13041369 bestGainStats = gainStats
13051370 bestFeatureIndex = featureIndex
13061371 bestSplitIndex = splitIndex
1372+ // println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
13071373 }
13081374 splitIndex += 1
13091375 }
@@ -1361,6 +1427,7 @@ object DecisionTree extends Serializable with Logging {
13611427 val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
13621428 logDebug(" parent node impurity = " + parentNodeImpurity)
13631429 bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
1430+ // println(s"bestSplits(node:$node): ${bestSplits(node)}")
13641431 node += 1
13651432 }
13661433 timer.chooseSplitsTime += timer.elapsed()
0 commit comments