@@ -58,7 +58,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
5858 // Find the splits and the corresponding bins (interval between the splits) using a sample
5959 // of the input data.
6060 val (splits, bins) = DecisionTree .findSplitsBins(input, strategy)
61- logDebug(" numSplits = " + bins(0 ).length)
61+ val numBins = bins(0 ).length
62+ logDebug(" numBins = " + numBins)
6263
6364 // depth of the decision tree
6465 val maxDepth = strategy.maxDepth
@@ -72,7 +73,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7273 val parentImpurities = new Array [Double ](maxNumNodes)
7374 // dummy value for top node (updated during first split calculation)
7475 val nodes = new Array [Node ](maxNumNodes)
76+ // num features
77+ val numFeatures = input.take(1 )(0 ).features.size
78+
79+ // Calculate level for single group construction
7580
81+ // Max memory usage for aggregates
82+ val maxMemoryUsage = scala.math.pow(2 , 27 ).toInt // 128MB
83+ logDebug(" max memory usage for aggregates = " + maxMemoryUsage)
84+ val numElementsPerNode = {
85+ strategy.algo match {
86+ case Classification => 2 * numBins * numFeatures
87+ case Regression => 3 * numBins * numFeatures
88+ }
89+ }
90+ logDebug(" numElementsPerNode = " + numElementsPerNode)
91+ val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
92+ val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1 )
93+ logDebug(" maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
94+ // nodes at a level is 2^level (level is zero indexed), so the "- 1" below is conservative by one level.
95+ val maxLevelForSingleGroup = scala.math.max(
96+ (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2 )).floor.toInt - 1 , 0 )
97+ logDebug(" max level for single group = " + maxLevelForSingleGroup)
7698
7799 /*
78100 * The main idea here is to perform level-wise training of the decision tree nodes thus
@@ -92,7 +114,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
92114
93115 // Find best split for all nodes at a level.
94116 val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy,
95- level, filters, splits, bins)
117+ level, filters, splits, bins, maxLevelForSingleGroup )
96118
97119 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
98120 // Extract info for nodes at the current level.
@@ -110,6 +132,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
110132 }
111133 }
112134
135+ logDebug(" #####################################" )
136+ logDebug(" Extracting tree model" )
137+ logDebug(" #####################################" )
138+
113139 // Initialize the top or root node of the tree.
114140 val topNode = nodes(0 )
115141 // Build the full tree using the node info calculated in the level-wise best split calculations.
@@ -260,6 +286,7 @@ object DecisionTree extends Serializable with Logging {
260286 * @param filters Filters for all nodes at a given level
261287 * @param splits possible splits for all features
262288 * @param bins possible bins for all features
289+ * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
263290 * @return array of splits with best splits for all nodes at a given level.
264291 */
265292 protected [tree] def findBestSplits (
@@ -269,7 +296,50 @@ object DecisionTree extends Serializable with Logging {
269296 level : Int ,
270297 filters : Array [List [Filter ]],
271298 splits : Array [Array [Split ]],
272- bins : Array [Array [Bin ]]): Array [(Split , InformationGainStats )] = {
299+ bins : Array [Array [Bin ]],
300+ maxLevelForSingleGroup : Int ): Array [(Split , InformationGainStats )] = {
301+ // split into groups to avoid memory overflow during aggregation
302+ if (level > maxLevelForSingleGroup) {
303+ val numGroups = scala.math.pow(2 , (level - maxLevelForSingleGroup)).toInt
304+ logDebug(" numGroups = " + numGroups)
305+ var groupIndex = 0
306+ var bestSplits = new Array [(Split , InformationGainStats )](0 )
307+ while (groupIndex < numGroups) {
308+ val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
309+ filters, splits, bins, numGroups, groupIndex)
310+ bestSplits = Array .concat(bestSplits, bestSplitsForGroup)
311+ groupIndex += 1
312+ }
313+ bestSplits
314+ } else {
315+ findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
316+ }
317+ }
318+
319+ /**
320+ * Returns an array of optimal splits for a group of nodes at a given level
321+ *
322+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
323+ * for DecisionTree
324+ * @param parentImpurities Impurities for all parent nodes for the current level
325+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy ]] instance containing
326+ * parameters for constructing the DecisionTree
327+ * @param level Level of the tree
328+ * @param filters Filters for all nodes at a given level
329+ * @param splits possible splits for all features
330+ * @param bins possible bins for all features
331+ * @return array of best splits for the nodes of this group at the given level.
332+ */
333+ private def findBestSplitsPerGroup (
334+ input : RDD [LabeledPoint ],
335+ parentImpurities : Array [Double ],
336+ strategy : Strategy ,
337+ level : Int ,
338+ filters : Array [List [Filter ]],
339+ splits : Array [Array [Split ]],
340+ bins : Array [Array [Bin ]],
341+ numGroups : Int = 1 ,
342+ groupIndex : Int = 0 ): Array [(Split , InformationGainStats )] = {
273343
274344 /*
275345 * The high-level description for the best split optimizations are noted here.
@@ -296,20 +366,23 @@ object DecisionTree extends Serializable with Logging {
296366 */
297367
298368 // common calculations for multiple nested methods
299- val numNodes = scala.math.pow(2 , level).toInt
369+ val numNodes = scala.math.pow(2 , level).toInt / numGroups
300370 logDebug(" numNodes = " + numNodes)
301371 // Find the number of features by looking at the first sample.
302372 val numFeatures = input.first().features.size
303373 logDebug(" numFeatures = " + numFeatures)
304374 val numBins = bins(0 ).length
305375 logDebug(" numBins = " + numBins)
306376
377+ // shift when more than one group is used at deep tree level
378+ val groupShift = numNodes * groupIndex
379+
307380 /** Find the filters used before reaching the current node. */
308381 def findParentFilters (nodeIndex : Int ): List [Filter ] = {
309382 if (level == 0 ) {
310383 List [Filter ]()
311384 } else {
312- val nodeFilterIndex = scala.math.pow(2 , level).toInt - 1 + nodeIndex
385+ val nodeFilterIndex = scala.math.pow(2 , level).toInt - 1 + nodeIndex + groupShift
313386 filters(nodeFilterIndex)
314387 }
315388 }
@@ -878,7 +951,7 @@ object DecisionTree extends Serializable with Logging {
878951 // Iterating over all nodes at this level
879952 var node = 0
880953 while (node < numNodes) {
881- val nodeImpurityIndex = scala.math.pow(2 , level).toInt - 1 + node
954+ val nodeImpurityIndex = scala.math.pow(2 , level).toInt - 1 + node + groupShift
882955 val binsForNode : Array [Double ] = getBinDataForNode(node)
883956 logDebug(" nodeImpurityIndex = " + nodeImpurityIndex)
884957 val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
0 commit comments