@@ -76,20 +76,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7676
7777 // Max memory usage for aggregates
7878 val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
79- logDebug(" max memory usage for aggregates = " + maxMemoryUsage)
79+ logDebug(" max memory usage for aggregates = " + maxMemoryUsage + " bytes. " )
8080 val numElementsPerNode = {
8181 strategy.algo match {
82- case Classification => 2 * numBins * numFeatures
82+ case Classification => 2 * numBins * numFeatures
8383 case Regression => 3 * numBins * numFeatures
8484 }
8585 }
8686 logDebug(" numElementsPerNode = " + numElementsPerNode)
8787 val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
8888 val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1 )
8989 logDebug(" maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
90- // nodes at a level is 2^( level-1) . level is zero indexed.
90+ // The number of nodes at a level is 2^level; level is zero-indexed.
9191 val maxLevelForSingleGroup = math.max(
92- (math.log(maxNumberOfNodesPerGroup) / math.log(2 )).floor.toInt - 1 , 0 )
92+ (math.log(maxNumberOfNodesPerGroup) / math.log(2 )).floor.toInt, 0 )
9393 logDebug(" max level for single group = " + maxLevelForSingleGroup)
9494
9595 /*
@@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging {
299299 bins : Array [Array [Bin ]],
300300 maxLevelForSingleGroup : Int ): Array [(Split , InformationGainStats )] = {
301301 // split into groups to avoid memory overflow during aggregation
302- if (level > maxLevelForSingleGroup) {
302+ if (level > maxLevelForSingleGroup) {
303+ // When information for all nodes at a given level cannot be stored in memory,
304+ // the nodes are divided into multiple groups at each level with the number of groups
305+ // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
306+ // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
303307 val numGroups = math.pow(2 , (level - maxLevelForSingleGroup)).toInt
304308 logDebug(" numGroups = " + numGroups)
305- var groupIndex = 0
306309 var bestSplits = new Array [(Split , InformationGainStats )](0 )
310+ // Iterate over each group of nodes at a level.
311+ var groupIndex = 0
307312 while (groupIndex < numGroups) {
308313 val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
309314 filters, splits, bins, numGroups, groupIndex)
0 commit comments