Skip to content

Commit 50b143a

Browse files
committed
adding support for very deep trees
1 parent 3a390bf commit 50b143a

File tree

2 files changed

+85
-12
lines changed

2 files changed

+85
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
5858
// Find the splits and the corresponding bins (interval between the splits) using a sample
5959
// of the input data.
6060
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
61-
logDebug("numSplits = " + bins(0).length)
61+
val numBins = bins(0).length
62+
logDebug("numBins = " + numBins)
6263

6364
// depth of the decision tree
6465
val maxDepth = strategy.maxDepth
@@ -72,7 +73,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7273
val parentImpurities = new Array[Double](maxNumNodes)
7374
// dummy value for top node (updated during first split calculation)
7475
val nodes = new Array[Node](maxNumNodes)
76+
// num features
77+
val numFeatures = input.take(1)(0).features.size
78+
79+
// Calculate level for single group construction
7580

81+
// Max memory usage for aggregates
82+
val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB
83+
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
84+
val numElementsPerNode = {
85+
strategy.algo match {
86+
case Classification => 2 * numBins * numFeatures
87+
case Regression => 3 * numBins * numFeatures
88+
}
89+
}
90+
logDebug("numElementsPerNode = " + numElementsPerNode)
91+
val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array
92+
val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1)
93+
logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
94+
// the number of nodes at a given level is 2^level (level is zero indexed).
95+
val maxLevelForSingleGroup = scala.math.max(
96+
(scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0)
97+
logDebug("max level for single group = " + maxLevelForSingleGroup)
7698

7799
/*
78100
* The main idea here is to perform level-wise training of the decision tree nodes thus
@@ -92,7 +114,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
92114

93115
// Find best split for all nodes at a level.
94116
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
95-
level, filters, splits, bins)
117+
level, filters, splits, bins, maxLevelForSingleGroup)
96118

97119
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
98120
// Extract info for nodes at the current level.
@@ -110,6 +132,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
110132
}
111133
}
112134

135+
logDebug("#####################################")
136+
logDebug("Extracting tree model")
137+
logDebug("#####################################")
138+
113139
// Initialize the top or root node of the tree.
114140
val topNode = nodes(0)
115141
// Build the full tree using the node info calculated in the level-wise best split calculations.
@@ -260,6 +286,7 @@ object DecisionTree extends Serializable with Logging {
260286
* @param filters Filters for all nodes at a given level
261287
* @param splits possible splits for all features
262288
* @param bins possible bins for all features
289+
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
263290
* @return array of splits with best splits for all nodes at a given level.
264291
*/
265292
protected[tree] def findBestSplits(
@@ -269,7 +296,50 @@ object DecisionTree extends Serializable with Logging {
269296
level: Int,
270297
filters: Array[List[Filter]],
271298
splits: Array[Array[Split]],
272-
bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
299+
bins: Array[Array[Bin]],
300+
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
301+
// split into groups to avoid memory overflow during aggregation
302+
if (level > maxLevelForSingleGroup) {
303+
val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt
304+
logDebug("numGroups = " + numGroups)
305+
var groupIndex = 0
306+
var bestSplits = new Array[(Split, InformationGainStats)](0)
307+
while (groupIndex < numGroups) {
308+
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
309+
filters, splits, bins, numGroups, groupIndex)
310+
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
311+
groupIndex += 1
312+
}
313+
bestSplits
314+
} else {
315+
findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
316+
}
317+
}
318+
319+
/**
320+
* Returns an array of optimal splits for a group of nodes at a given level
321+
*
322+
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
323+
* for DecisionTree
324+
* @param parentImpurities Impurities for all parent nodes for the current level
325+
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
326+
* parameters for constructing the DecisionTree
327+
* @param level Level of the tree
328+
* @param filters Filters for all nodes at a given level
329+
* @param splits possible splits for all features
330+
* @param bins possible bins for all features
331+
* @return array of splits with best splits for the group of nodes at a given level.
332+
*/
333+
private def findBestSplitsPerGroup(
334+
input: RDD[LabeledPoint],
335+
parentImpurities: Array[Double],
336+
strategy: Strategy,
337+
level: Int,
338+
filters: Array[List[Filter]],
339+
splits: Array[Array[Split]],
340+
bins: Array[Array[Bin]],
341+
numGroups: Int = 1,
342+
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
273343

274344
/*
275345
* The high-level description for the best split optimizations are noted here.
@@ -296,20 +366,23 @@ object DecisionTree extends Serializable with Logging {
296366
*/
297367

298368
// common calculations for multiple nested methods
299-
val numNodes = scala.math.pow(2, level).toInt
369+
val numNodes = scala.math.pow(2, level).toInt / numGroups
300370
logDebug("numNodes = " + numNodes)
301371
// Find the number of features by looking at the first sample.
302372
val numFeatures = input.first().features.size
303373
logDebug("numFeatures = " + numFeatures)
304374
val numBins = bins(0).length
305375
logDebug("numBins = " + numBins)
306376

377+
// shift when more than one group is used at deep tree level
378+
val groupShift = numNodes * groupIndex
379+
307380
/** Find the filters used before reaching the current code. */
308381
def findParentFilters(nodeIndex: Int): List[Filter] = {
309382
if (level == 0) {
310383
List[Filter]()
311384
} else {
312-
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
385+
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift
313386
filters(nodeFilterIndex)
314387
}
315388
}
@@ -878,7 +951,7 @@ object DecisionTree extends Serializable with Logging {
878951
// Iterating over all nodes at this level
879952
var node = 0
880953
while (node < numNodes) {
881-
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
954+
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift
882955
val binsForNode: Array[Double] = getBinDataForNode(node)
883956
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
884957
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
254254
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
255255
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
256256
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
257-
Array[List[Filter]](), splits, bins)
257+
Array[List[Filter]](), splits, bins, 10)
258258

259259
val split = bestSplits(0)._1
260260
assert(split.categories.length === 1)
@@ -281,7 +281,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
281281
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
282282
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
283283
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
284-
Array[List[Filter]](), splits, bins)
284+
Array[List[Filter]](), splits, bins, 10)
285285

286286
val split = bestSplits(0)._1
287287
assert(split.categories.length === 1)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
310310
assert(bins(0).length === 100)
311311

312312
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
313-
Array[List[Filter]](), splits, bins)
313+
Array[List[Filter]](), splits, bins, 10)
314314
assert(bestSplits.length === 1)
315315
assert(bestSplits(0)._1.feature === 0)
316316
assert(bestSplits(0)._1.threshold === 10)
@@ -333,7 +333,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
333333
assert(bins(0).length === 100)
334334

335335
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
336-
Array[List[Filter]](), splits, bins)
336+
Array[List[Filter]](), splits, bins, 10)
337337
assert(bestSplits.length === 1)
338338
assert(bestSplits(0)._1.feature === 0)
339339
assert(bestSplits(0)._1.threshold === 10)
@@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
357357
assert(bins(0).length === 100)
358358

359359
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
360-
Array[List[Filter]](), splits, bins)
360+
Array[List[Filter]](), splits, bins, 10)
361361
assert(bestSplits.length === 1)
362362
assert(bestSplits(0)._1.feature === 0)
363363
assert(bestSplits(0)._1.threshold === 10)
@@ -381,7 +381,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
381381
assert(bins(0).length === 100)
382382

383383
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
384-
Array[List[Filter]](), splits, bins)
384+
Array[List[Filter]](), splits, bins, 10)
385385
assert(bestSplits.length === 1)
386386
assert(bestSplits(0)._1.feature === 0)
387387
assert(bestSplits(0)._1.threshold === 10)

0 commit comments

Comments
 (0)