Skip to content

Commit 0f676e2

Browse files
committed
Optimizations + Bug fix for DecisionTree
Optimization: Added TreePoint representation so we only call findBin once for each (example, feature) pair. Also, calculateGainsForAllNodeSplits now only searches over actual splits, not empty/unused ones. BUG FIX: isSampleValid. * isSampleValid used to treat unordered categorical features incorrectly: it treated the bins as if they were indexed by feature values, rather than by subsets of values/categories. * The bug was exhibited for unordered features (multi-class classification with categorical features of low arity). * Fix: Index bins correctly for unordered categorical features. Also: added some commented-out debugging println calls in DecisionTree, to be removed later.
1 parent 3211f02 commit 0f676e2

File tree

1 file changed

+95
-28
lines changed

1 file changed

+95
-28
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 95 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
252252
// noting the parents filters for the child nodes
253253
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
254254
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
255+
//println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
255256
for (filter <- filters(nodeIndex)) {
256257
logDebug("Filter = " + filter)
257258
}
@@ -477,7 +478,7 @@ object DecisionTree extends Serializable with Logging {
477478
* @param splits possible splits for all features
478479
* @param bins possible bins for all features
479480
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
480-
* @return array of splits with best splits for all nodes at a given level.
481+
* @return array (over nodes) of splits with best split for each node at a given level.
481482
*/
482483
protected[tree] def findBestSplits(
483484
input: RDD[TreePoint],
@@ -490,6 +491,7 @@ object DecisionTree extends Serializable with Logging {
490491
maxLevelForSingleGroup: Int,
491492
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
492493
// split into groups to avoid memory overflow during aggregation
494+
//println(s"findBestSplits: level = $level")
493495
if (level > maxLevelForSingleGroup) {
494496
// When information for all nodes at a given level cannot be stored in memory,
495497
// the nodes are divided into multiple groups at each level with the number of groups
@@ -617,22 +619,32 @@ object DecisionTree extends Serializable with Logging {
617619
val featureIndex = filter.split.feature
618620
val comparison = filter.comparison
619621
val isFeatureContinuous = filter.split.featureType == Continuous
620-
val binId = treePoint.features(featureIndex)
621-
val bin = bins(featureIndex)(binId)
622622
if (isFeatureContinuous) {
623+
val binId = treePoint.features(featureIndex)
624+
val bin = bins(featureIndex)(binId)
623625
val featureValue = bin.highSplit.threshold
624626
val threshold = filter.split.threshold
625627
comparison match {
626628
case -1 => if (featureValue > threshold) return false
627629
case 1 => if (featureValue <= threshold) return false
628630
}
629631
} else {
630-
val containsFeature = filter.split.categories.contains(bin.category)
632+
val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
633+
val isSpaceSufficientForAllCategoricalSplits =
634+
numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
635+
val isUnorderedFeature =
636+
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
637+
val featureValue = if (isUnorderedFeature) {
638+
treePoint.features(featureIndex)
639+
} else {
640+
val binId = treePoint.features(featureIndex)
641+
bins(featureIndex)(binId).category
642+
}
643+
val containsFeature = filter.split.categories.contains(featureValue)
631644
comparison match {
632645
case -1 => if (!containsFeature) return false
633646
case 1 => if (containsFeature) return false
634647
}
635-
636648
}
637649
}
638650

@@ -669,6 +681,7 @@ object DecisionTree extends Serializable with Logging {
669681
val parentFilters = findParentFilters(nodeIndex)
670682
// Find out whether the sample qualifies for the particular node.
671683
val sampleValid = isSampleValid(parentFilters, treePoint)
684+
//println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
672685
val shift = 1 + numFeatures * nodeIndex
673686
if (!sampleValid) {
674687
// Mark one bin as -1 is sufficient.
@@ -739,6 +752,7 @@ object DecisionTree extends Serializable with Logging {
739752
label: Double,
740753
agg: Array[Double],
741754
rightChildShift: Int): Unit = {
755+
//println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
742756
// Find the bin index for this feature.
743757
val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
744758
val featureValue = arr(arrIndex).toInt
@@ -792,6 +806,8 @@ object DecisionTree extends Serializable with Logging {
792806
}
793807
}
794808

809+
val rightChildShift = numClasses * numBins * numFeatures * numNodes
810+
795811
/**
796812
* Helper for binSeqOp.
797813
*
@@ -814,8 +830,11 @@ object DecisionTree extends Serializable with Logging {
814830
// Check whether the instance was valid for this nodeIndex.
815831
val validSignalIndex = 1 + numFeatures * nodeIndex
816832
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
833+
if (level == 1) {
834+
val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
835+
//println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
836+
}
817837
if (isSampleValidForNode) {
818-
val rightChildShift = numClasses * numBins * numFeatures * numNodes
819838
// actual class label
820839
val label = arr(0)
821840
// Iterate over all features.
@@ -874,7 +893,7 @@ object DecisionTree extends Serializable with Logging {
874893
val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
875894
agg(aggIndex) = agg(aggIndex) + 1
876895
agg(aggIndex + 1) = agg(aggIndex + 1) + label
877-
agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
896+
agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
878897
featureIndex += 1
879898
}
880899
}
@@ -944,6 +963,29 @@ object DecisionTree extends Serializable with Logging {
944963
logDebug("binAggregates.length = " + binAggregates.length)
945964

946965
timer.binAggregatesTime += timer.elapsed()
966+
//2 * numClasses * numBins * numFeatures * numNodes for unordered features.
967+
// (left/right, node, feature, bin, label)
968+
/*
969+
println(s"binAggregates:")
970+
for (i <- Range(0,2)) {
971+
for (n <- Range(0,numNodes)) {
972+
for (f <- Range(0,numFeatures)) {
973+
for (b <- Range(0,4)) {
974+
for (c <- Range(0,numClasses)) {
975+
val idx = i * numClasses * numBins * numFeatures * numNodes +
976+
n * numClasses * numBins * numFeatures +
977+
f * numBins * numFeatures +
978+
b * numFeatures +
979+
c
980+
if (binAggregates(idx) != 0) {
981+
println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
982+
}
983+
}
984+
}
985+
}
986+
}
987+
}
988+
*/
947989

948990
/**
949991
* Calculates the information gain for all splits based upon left/right split aggregates.
@@ -985,6 +1027,7 @@ object DecisionTree extends Serializable with Logging {
9851027
val totalCount = leftTotalCount + rightTotalCount
9861028
if (totalCount == 0) {
9871029
// Return arbitrary prediction.
1030+
//println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
9881031
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
9891032
}
9901033

@@ -997,13 +1040,23 @@ object DecisionTree extends Serializable with Logging {
9971040
def indexOfLargestArrayElement(array: Array[Double]): Int = {
9981041
val result = array.foldLeft(-1, Double.MinValue, 0) {
9991042
case ((maxIndex, maxValue, currentIndex), currentValue) =>
1000-
if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
1001-
else (maxIndex, maxValue, currentIndex + 1)
1043+
if (currentValue > maxValue) {
1044+
(currentIndex, currentValue, currentIndex + 1)
1045+
} else {
1046+
(maxIndex, maxValue, currentIndex + 1)
1047+
}
10021048
}
1003-
if (result._1 < 0) 0 else result._1
1049+
if (result._1 < 0) {
1050+
throw new RuntimeException("DecisionTree internal error:" +
1051+
" calculateGainForSplit failed in indexOfLargestArrayElement")
1052+
}
1053+
result._1
10041054
}
10051055

10061056
val predict = indexOfLargestArrayElement(leftRightCounts)
1057+
if (predict == 0 && featureIndex == 0 && splitIndex == 0) {
1058+
//println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
1059+
}
10071060
val prob = leftRightCounts(predict) / totalCount
10081061

10091062
val leftImpurity = if (leftTotalCount == 0) {
@@ -1023,6 +1076,7 @@ object DecisionTree extends Serializable with Logging {
10231076
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
10241077

10251078
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
1079+
10261080
case Regression =>
10271081
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
10281082
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -1140,6 +1194,7 @@ object DecisionTree extends Serializable with Logging {
11401194

11411195
val rightChildShift = numClasses * numBins * numFeatures
11421196
var splitIndex = 0
1197+
var TMPDEBUG = 0.0
11431198
while (splitIndex < numBins - 1) {
11441199
var classIndex = 0
11451200
while (classIndex < numClasses) {
@@ -1149,10 +1204,12 @@ object DecisionTree extends Serializable with Logging {
11491204
val rightBinValue = binData(rightChildShift + shift + classIndex)
11501205
leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
11511206
rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
1207+
TMPDEBUG += leftBinValue + rightBinValue
11521208
classIndex += 1
11531209
}
11541210
splitIndex += 1
11551211
}
1212+
//println(s"found Agg: $TMPDEBUG")
11561213
}
11571214

11581215
def findAggForRegression(
@@ -1247,14 +1304,36 @@ object DecisionTree extends Serializable with Logging {
12471304
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
12481305

12491306
for (featureIndex <- 0 until numFeatures) {
1250-
for (splitIndex <- 0 until numBins - 1) {
1307+
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1308+
for (splitIndex <- 0 until numSplitsForFeature) {
12511309
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
12521310
splitIndex, rightNodeAgg, nodeImpurity)
12531311
}
12541312
}
12551313
gains
12561314
}
12571315

1316+
/**
1317+
* Get the number of splits for a feature.
1318+
*/
1319+
def getNumSplitsForFeature(featureIndex: Int): Int = {
1320+
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
1321+
if (isFeatureContinuous) {
1322+
numBins - 1
1323+
} else {
1324+
// Categorical feature
1325+
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1326+
val isSpaceSufficientForAllCategoricalSplits =
1327+
numBins > math.pow(2, featureCategories.toInt - 1) - 1
1328+
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
1329+
math.pow(2.0, featureCategories - 1).toInt - 1
1330+
} else {
1331+
// Ordered features
1332+
featureCategories
1333+
}
1334+
}
1335+
}
1336+
12581337
/**
12591338
* Find the best split for a node.
12601339
* @param binData Bin data slice for this node, given by getBinDataForNode.
@@ -1273,7 +1352,7 @@ object DecisionTree extends Serializable with Logging {
12731352
// Calculate gains for all splits.
12741353
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
12751354

1276-
val (bestFeatureIndex,bestSplitIndex, gainStats) = {
1355+
val (bestFeatureIndex, bestSplitIndex, gainStats) = {
12771356
// Initialize with infeasible values.
12781357
var bestFeatureIndex = Int.MinValue
12791358
var bestSplitIndex = Int.MinValue
@@ -1283,27 +1362,14 @@ object DecisionTree extends Serializable with Logging {
12831362
while (featureIndex < numFeatures) {
12841363
// Iterate over all splits.
12851364
var splitIndex = 0
1286-
val maxSplitIndex: Double = {
1287-
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
1288-
if (isFeatureContinuous) {
1289-
numBins - 1
1290-
} else { // Categorical feature
1291-
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1292-
val isSpaceSufficientForAllCategoricalSplits
1293-
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
1294-
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
1295-
math.pow(2.0, featureCategories - 1).toInt - 1
1296-
} else { // Binary classification
1297-
featureCategories
1298-
}
1299-
}
1300-
}
1301-
while (splitIndex < maxSplitIndex) {
1365+
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1366+
while (splitIndex < numSplitsForFeature) {
13021367
val gainStats = gains(featureIndex)(splitIndex)
13031368
if (gainStats.gain > bestGainStats.gain) {
13041369
bestGainStats = gainStats
13051370
bestFeatureIndex = featureIndex
13061371
bestSplitIndex = splitIndex
1372+
//println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
13071373
}
13081374
splitIndex += 1
13091375
}
@@ -1361,6 +1427,7 @@ object DecisionTree extends Serializable with Logging {
13611427
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
13621428
logDebug("parent node impurity = " + parentNodeImpurity)
13631429
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
1430+
//println(s"bestSplits(node:$node): ${bestSplits(node)}")
13641431
node += 1
13651432
}
13661433
timer.chooseSplitsTime += timer.elapsed()

0 commit comments

Comments
 (0)