@@ -76,20 +76,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
76
76
77
77
// Max memory usage for aggregates
78
78
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
79
- logDebug(" max memory usage for aggregates = " + maxMemoryUsage)
79
+ logDebug(" max memory usage for aggregates = " + maxMemoryUsage + " bytes. " )
80
80
val numElementsPerNode = {
81
81
strategy.algo match {
82
- case Classification => 2 * numBins * numFeatures
82
+ case Classification => 2 * numBins * numFeatures
83
83
case Regression => 3 * numBins * numFeatures
84
84
}
85
85
}
86
86
logDebug(" numElementsPerNode = " + numElementsPerNode)
87
87
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
88
88
val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1 )
89
89
logDebug(" maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
90
- // nodes at a level is 2^( level-1) . level is zero indexed.
90
+ // The number of nodes at a level is 2^level, where level is zero-indexed.
91
91
val maxLevelForSingleGroup = math.max(
92
- (math.log(maxNumberOfNodesPerGroup) / math.log(2 )).floor.toInt - 1 , 0 )
92
+ (math.log(maxNumberOfNodesPerGroup) / math.log(2 )).floor.toInt, 0 )
93
93
logDebug(" max level for single group = " + maxLevelForSingleGroup)
94
94
95
95
/*
@@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging {
299
299
bins : Array [Array [Bin ]],
300
300
maxLevelForSingleGroup : Int ): Array [(Split , InformationGainStats )] = {
301
301
// split into groups to avoid memory overflow during aggregation
302
- if (level > maxLevelForSingleGroup) {
302
+ if (level > maxLevelForSingleGroup) {
303
+ // When information for all nodes at a given level cannot be stored in memory,
304
+ // the nodes are divided into multiple groups at each level with the number of groups
305
+ // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
306
+ // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
303
307
val numGroups = math.pow(2 , (level - maxLevelForSingleGroup)).toInt
304
308
logDebug(" numGroups = " + numGroups)
305
- var groupIndex = 0
306
309
var bestSplits = new Array [(Split , InformationGainStats )](0 )
310
+ // Iterate over each group of nodes at a level.
311
+ var groupIndex = 0
307
312
while (groupIndex < numGroups) {
308
313
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
309
314
filters, splits, bins, numGroups, groupIndex)
0 commit comments