Skip to content

Commit 5e82202

Browse files
committed
added documentation, fixed off by 1 error in max level calculation
1 parent cbd9f14 commit 5e82202

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7676

7777
// Max memory usage for aggregates
7878
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
79-
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
79+
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
8080
val numElementsPerNode = {
8181
strategy.algo match {
82-
case Classification => 2 * numBins * numFeatures
82+
case Classification => 2 * numBins * numFeatures
8383
case Regression => 3 * numBins * numFeatures
8484
}
8585
}
8686
logDebug("numElementsPerNode = " + numElementsPerNode)
8787
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
8888
val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
8989
logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
90-
// nodes at a level is 2^(level-1). level is zero indexed.
90+
// The number of nodes at a level is 2^level, where level is zero-indexed.
9191
val maxLevelForSingleGroup = math.max(
92-
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0)
92+
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
9393
logDebug("max level for single group = " + maxLevelForSingleGroup)
9494

9595
/*
@@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging {
299299
bins: Array[Array[Bin]],
300300
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
301301
// split into groups to avoid memory overflow during aggregation
302-
if (level > maxLevelForSingleGroup) {
302+
if (level > maxLevelForSingleGroup) {
303+
// When information for all nodes at a given level cannot be stored in memory,
304+
// the nodes are divided into multiple groups at each level with the number of groups
305+
// increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
306+
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
303307
val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt
304308
logDebug("numGroups = " + numGroups)
305-
var groupIndex = 0
306309
var bestSplits = new Array[(Split, InformationGainStats)](0)
310+
// Iterate over each group of nodes at a level.
311+
var groupIndex = 0
307312
while (groupIndex < numGroups) {
308313
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
309314
filters, splits, bins, numGroups, groupIndex)

0 commit comments

Comments
 (0)