
Commit 56f2c61

Sung Chung authored and mengxr committed
[SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.

jkbradley mengxr chouqin Please review this.

Author: Sung Chung <schung@alpinenow.com>

Closes #2868 from codedeft/SPARK-3161 and squashes the following commits:

5f5a156 [Sung Chung] [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.
1 parent d8176b1 commit 56f2c61

File tree

6 files changed, 405 additions and 41 deletions


examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 23 additions & 2 deletions
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
       minInfoGain: Double = 0.0,
       numTrees: Int = 1,
       featureSubsetStrategy: String = "auto",
-      fracTest: Double = 0.2) extends AbstractParams[Params]
+      fracTest: Double = 0.2,
+      useNodeIdCache: Boolean = false,
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
 
   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
         .text(s"fraction of data to hold out for testing. If given option testInput, " +
           s"this option is ignored. default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
+      opt[Boolean]("useNodeIdCache")
+        .text(s"whether to use node Id cache during training, " +
+          s"default: ${defaultParams.useNodeIdCache}")
+        .action((x, c) => c.copy(useNodeIdCache = x))
+      opt[String]("checkpointDir")
+        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+          s"default: ${defaultParams.checkpointDir match {
+            case Some(strVal) => strVal
+            case None => "None"
+          }}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"how often to checkpoint the node Id cache, " +
+          s"default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
       opt[String]("testInput")
         .text(s"input path to test dataset. If given, option fracTest is ignored." +
           s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
         maxBins = params.maxBins,
         numClassesForClassification = numClasses,
         minInstancesPerNode = params.minInstancesPerNode,
-        minInfoGain = params.minInfoGain)
+        minInfoGain = params.minInfoGain,
+        useNodeIdCache = params.useNodeIdCache,
+        checkpointDir = params.checkpointDir,
+        checkpointInterval = params.checkpointInterval)
       if (params.numTrees == 1) {
         val startTime = System.nanoTime()
         val model = DecisionTree.train(training, strategy)
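The same configuration can be expressed programmatically. Below is a minimal sketch (not part of this commit), assuming the Spark 1.2-era MLlib API shown in this diff; the object name, depth, and checkpoint path are illustrative only:

import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object NodeIdCacheConfigSketch {
  def main(args: Array[String]): Unit = {
    // The three new Strategy fields switch on node Id caching and control its checkpointing.
    val strategy = new Strategy(
      algo = Classification,
      impurity = Gini,
      maxDepth = 10,
      numClassesForClassification = 2,
      useNodeIdCache = true,                      // new in this commit
      checkpointDir = Some("/tmp/node-id-cache"), // hypothetical directory
      checkpointInterval = 10)                    // checkpoint every 10 cache updates
    // The strategy would then be passed to DecisionTree.train(training, strategy),
    // exactly as DecisionTreeRunner does above; nothing else about the call changes.
    println(strategy.useNodeIdCache)
  }
}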

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 98 additions & 16 deletions
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
    *                  Updated with new non-leaf nodes which are created.
+   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+   *                    each value in the array is the data point's node Id
+   *                    for a corresponding tree. This is used to prevent the need
+   *                    to pass the entire tree to the executors during
+   *                    the node stat aggregation phase.
    */
   private[tree] def findBestSplits(
       input: RDD[BaggedPoint[TreePoint]],
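To make the cached representation described by the @param nodeIdCache doc above concrete, here is a self-contained toy in plain Scala (an illustration only, not the MLlib implementation): each training instance carries one node Id per tree, initialized to the root (Id 1), so locating an instance's current node becomes an array lookup rather than a root-to-leaf traversal of a tree shipped to the executor.

object NodeIdCacheToy {
  def main(args: Array[String]): Unit = {
    val numTrees = 3
    // One Array[Int] per instance; every instance starts at the root, node Id 1.
    val cachedNodeIds = Array.fill(5)(Array.fill(numTrees)(1))
    // Once the root of tree 0 splits, an updater would overwrite each instance's
    // entry for tree 0 with the left (2) or right (3) child index, for example:
    cachedNodeIds(0)(0) = 2
    println(cachedNodeIds.map(_.mkString("[", ",", "]")).mkString(" "))
  }
}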
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       nodeQueue: mutable.Queue[(Int, Node)],
-      timer: TimeTracker = new TimeTracker): Unit = {
+      timer: TimeTracker = new TimeTracker,
+      nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
     logDebug("isMulticlass = " + metadata.isMulticlass)
     logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
+    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+    /**
+     * Performs a sequential aggregation over a partition for a particular tree and node.
+     *
+     * For each feature, the aggregate sufficient statistics are updated for the relevant
+     * bins.
+     *
+     * @param treeIndex Index of the tree that we want to perform aggregation for.
+     * @param nodeInfo The node info for the tree node.
+     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+     *            for each (node, feature, bin).
+     * @param baggedPoint Data point being aggregated.
+     */
+    def nodeBinSeqOp(
+        treeIndex: Int,
+        nodeInfo: RandomForest.NodeIndexInfo,
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Unit = {
+      if (nodeInfo != null) {
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val featuresForNode = nodeInfo.featureSubset
+        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+        if (metadata.unorderedFeatures.isEmpty) {
+          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+        } else {
+          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+            instanceWeight, featuresForNode)
+        }
+      }
+    }
 
     /**
      * Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
         val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
           bins, metadata.unorderedFeatures)
-        val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
-        // If the example does not reach a node in this group, then nodeIndex = null.
-        if (nodeInfo != null) {
-          val aggNodeIndex = nodeInfo.nodeIndexInGroup
-          val featuresForNode = nodeInfo.featureSubset
-          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
-          if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
-          } else {
-            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
-              instanceWeight, featuresForNode)
-          }
-        }
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+      }
+
+      agg
+    }
+
+    /**
+     * Do the same thing as binSeqOp, but with nodeIdCache.
+     */
+    def binSeqOpWithNodeIdCache(
+        agg: Array[DTStatsAggregator],
+        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        val baggedPoint = dataPoint._1
+        val nodeIdCache = dataPoint._2
+        val nodeIndex = nodeIdCache(treeIndex)
+        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
       }
 
       agg
     }
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
     // Finally, only best Splits for nodes are collected to driver to construct decision tree.
     val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
     val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
-    val nodeToBestSplits =
+
+    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterator all instances in current partition and update aggregate stats
+        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }
+    } else {
       input.mapPartitions { points =>
         // Construct a nodeStatsAggregators array to hold node aggregate stats,
         // each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
         // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
         // which can be combined with other partition using `reduceByKey`
         nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
-      }.reduceByKey((a, b) => a.merge(b))
+      }
+    }
+
+    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
       .map { case (nodeIndex, aggStats) =>
         val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
           Some(nodeToFeatures(nodeIndex))
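The cached branch above is an instance of a generic RDD pattern: zip the instances with their cached node Ids, build one aggregate per node inside each partition with mapPartitions, then merge the partial aggregates across partitions with reduceByKey. A minimal, self-contained sketch of that pattern (assuming a local Spark context; plain sums stand in for the real sufficient-statistics aggregators):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits on pre-1.3 Spark

object ZipAggregateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("sketch"))
    val points = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), numSlices = 2)
    val nodeIds = sc.parallelize(Seq(1, 1, 2, 3), numSlices = 2) // cached node Id per point
    val perNodeSums = points.zip(nodeIds)
      .mapPartitions { iter =>
        // Aggregate within the partition: nodeId -> running sum.
        val agg = scala.collection.mutable.Map.empty[Int, Double].withDefaultValue(0.0)
        iter.foreach { case (x, id) => agg(id) += x }
        agg.iterator
      }
      .reduceByKey(_ + _) // merge partial aggregates across partitions
    perNodeSums.collect().foreach(println)
    sc.stop()
  }
}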
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
 
     timer.stop("chooseSplits")
 
+    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+    } else {
+      null
+    }
+
     // Iterate over all nodes in this group.
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
       nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
         node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
           stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
 
+        if (nodeIdCache.nonEmpty) {
+          val nodeIndexUpdater = NodeIndexUpdater(
+            split = split,
+            nodeIndex = nodeIndex)
+          nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+        }
+
         // enqueue left child and right child if they are not leaves
         if (!leftChildIsLeaf) {
           nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
+    if (nodeIdCache.nonEmpty) {
+      // Update the cache if needed.
+      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+    }
   }
 
   /**

mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala

Lines changed: 20 additions & 2 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
 import org.apache.spark.mllib.tree.impurity.Impurities
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
      * in lower levels).
      */
 
+    // Create an RDD of node Id cache.
+    // At first, all the rows belong to the root nodes (node Id == 1).
+    val nodeIdCache = if (strategy.useNodeIdCache) {
+      Some(NodeIdCache.init(
+        data = baggedInput,
+        numTrees = numTrees,
+        checkpointDir = strategy.checkpointDir,
+        checkpointInterval = strategy.checkpointInterval,
+        initVal = 1))
+    } else {
+      None
+    }
+
     // FIFO queue of nodes to train: (treeIndex, node)
     val nodeQueue = new mutable.Queue[(Int, Node)]()
 
@@ -182,7 +195,7 @@ private class RandomForest (
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
       DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
       timer.stop("findBestSplits")
     }
 
@@ -193,6 +206,11 @@ private class RandomForest (
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    // Delete any remaining checkpoints used for node Id cache.
+    if (nodeIdCache.nonEmpty) {
+      nodeIdCache.get.deleteAllCheckpoints()
+    }
+
     val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
     val treeWeights = Array.fill[Double](numTrees)(1.0)
     new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 11 additions & 1 deletion
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
  *                      256 MB.
  * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
+ *                       maintain a separate RDD of node Id cache for each row.
+ * @param checkpointDir If the node Id cache is used, it will help to checkpoint
+ *                      the node Id cache periodically. This is the checkpoint directory
+ *                      to be used for the node Id cache.
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ *                           E.g. 10 means that the cache will get checkpointed every 10 updates.
  */
 @Experimental
 class Strategy (
@@ -73,7 +80,10 @@ class Strategy (
     @BeanProperty var minInstancesPerNode: Int = 1,
     @BeanProperty var minInfoGain: Double = 0.0,
     @BeanProperty var maxMemoryInMB: Int = 256,
-    @BeanProperty var subsamplingRate: Double = 1) extends Serializable {
+    @BeanProperty var subsamplingRate: Double = 1,
+    @BeanProperty var useNodeIdCache: Boolean = false,
+    @BeanProperty var checkpointDir: Option[String] = None,
+    @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
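Because the new fields are declared as @BeanProperty vars, matching getters and setters are generated, so the cache can also be switched on after construction (for example from Java code). A short sketch, again only a hedged illustration reusing constructor parameters visible in this diff; the depth and directory are made up:

import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object NodeIdCacheBeanSketch {
  def main(args: Array[String]): Unit = {
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 10)
    // Setters generated by @BeanProperty for the fields added in this commit:
    strategy.setUseNodeIdCache(true)
    strategy.setCheckpointDir(Some("/tmp/node-id-cache")) // hypothetical directory
    strategy.setCheckpointInterval(5)
    println(strategy.getUseNodeIdCache)
  }
}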
