@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
437
437
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
438
438
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
439
439
* Updated with new non-leaf nodes which are created.
440
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
441
+ * each value in the array is the data point's node Id
442
+ * for a corresponding tree. This is used to prevent the need
443
+ * to pass the entire tree to the executors during
444
+ * the node stat aggregation phase.
440
445
*/
441
446
private [tree] def findBestSplits (
442
447
input : RDD [BaggedPoint [TreePoint ]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
447
452
splits : Array [Array [Split ]],
448
453
bins : Array [Array [Bin ]],
449
454
nodeQueue : mutable.Queue [(Int , Node )],
450
- timer : TimeTracker = new TimeTracker ): Unit = {
455
+ timer : TimeTracker = new TimeTracker ,
456
+ nodeIdCache : Option [NodeIdCache ] = None ): Unit = {
451
457
452
458
/*
453
459
* The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
479
485
logDebug(" isMulticlass = " + metadata.isMulticlass)
480
486
logDebug(" isMulticlassWithCategoricalFeatures = " +
481
487
metadata.isMulticlassWithCategoricalFeatures)
488
+ logDebug(" using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
489
+
490
/**
 * Folds one data point into the aggregator of a single node of a single tree.
 *
 * The point's statistics are added to the relevant bins of every feature in the
 * node's feature subset. Nothing happens when `nodeInfo` is null.
 *
 * @param treeIndex Index of the tree the aggregation is performed for; also selects
 *                  the point's subsample weight.
 * @param nodeInfo The node info for the tree node, or null when the callers' lookup in
 *                 the group's node map found no entry for the point's node index.
 * @param agg Array storing aggregate calculation, with a set of sufficient statistics
 *            for each (node, feature, bin); indexed by the node's index in the group.
 * @param baggedPoint Data point being aggregated.
 */
def nodeBinSeqOp(
    treeIndex: Int,
    nodeInfo: RandomForest.NodeIndexInfo,
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Unit = {
  // Callers hand over null for a missing map entry, so lift it into an Option
  // and only aggregate when a node info is actually present.
  Option(nodeInfo).foreach { info =>
    val indexInGroup = info.nodeIndexInGroup
    val featureSubset = info.featureSubset
    val weight = baggedPoint.subsampleWeights(treeIndex)
    if (metadata.unorderedFeatures.nonEmpty) {
      // At least one unordered feature: use the mixed ordered/unordered update.
      mixedBinSeqOp(agg(indexInGroup), baggedPoint.datum, bins, metadata.unorderedFeatures,
        weight, featureSubset)
    } else {
      orderedBinSeqOp(agg(indexInGroup), baggedPoint.datum, weight, featureSubset)
    }
  }
}
482
519
483
520
/**
484
521
* Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
497
534
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
498
535
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
499
536
bins, metadata.unorderedFeatures)
500
- val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null )
501
- // If the example does not reach a node in this group, then nodeIndex = null.
502
- if (nodeInfo != null ) {
503
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
504
- val featuresForNode = nodeInfo.featureSubset
505
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
506
- if (metadata.unorderedFeatures.isEmpty) {
507
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
508
- } else {
509
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
510
- instanceWeight, featuresForNode)
511
- }
512
- }
537
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null ), agg, baggedPoint)
538
+ }
539
+
540
+ agg
541
+ }
542
+
543
/**
 * Performs the same sequential aggregation as binSeqOp, but reads each tree's node
 * index for the point from the cached node Ids instead of computing it from the tree.
 *
 * @param agg Array storing aggregate calculation, with a set of sufficient statistics
 *            for each (node, feature, bin).
 * @param dataPoint Pair of the data point and its cached node Ids, the latter
 *                  indexed by tree index.
 * @return The same `agg` array, updated in place.
 */
def binSeqOpWithNodeIdCache(
    agg: Array[DTStatsAggregator],
    dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
  treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    val point = dataPoint._1
    val cachedNodeId = dataPoint._2(treeIndex)
    // A node Id with no entry in this group's map yields null, which nodeBinSeqOp skips.
    val info = nodeIndexToInfo.getOrElse(cachedNodeId, null)
    nodeBinSeqOp(treeIndex, info, agg, point)
  }

  agg
}
516
558
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
553
595
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
554
596
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
555
597
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
556
- val nodeToBestSplits =
598
+
599
+ val partitionAggregates : RDD [(Int , DTStatsAggregator )] = if (nodeIdCache.nonEmpty) {
600
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
601
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
602
+ // each node will have a nodeStatsAggregator
603
+ val nodeStatsAggregators = Array .tabulate(numNodes) { nodeIndex =>
604
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
605
+ Some (nodeToFeatures(nodeIndex))
606
+ }
607
+ new DTStatsAggregator (metadata, featuresForNode)
608
+ }
609
+
610
+ // iterate over all instances in the current partition and update aggregate stats
611
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
612
+
613
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
614
+ // which can be combined with other partition using `reduceByKey`
615
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
616
+ }
617
+ } else {
557
618
input.mapPartitions { points =>
558
619
// Construct a nodeStatsAggregators array to hold node aggregate stats,
559
620
// each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
570
631
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
571
632
// which can be combined with other partition using `reduceByKey`
572
633
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
573
- }.reduceByKey((a, b) => a.merge(b))
634
+ }
635
+ }
636
+
637
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
574
638
.map { case (nodeIndex, aggStats) =>
575
639
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
576
640
Some (nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
584
648
585
649
timer.stop(" chooseSplits" )
586
650
651
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
652
+ Array .fill[mutable.Map [Int , NodeIndexUpdater ]](
653
+ metadata.numTrees)(mutable.Map [Int , NodeIndexUpdater ]())
654
+ } else {
655
+ null
656
+ }
657
+
587
658
// Iterate over all nodes in this group.
588
659
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
589
660
nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
613
684
node.rightNode = Some (Node (Node .rightChildIndex(nodeIndex),
614
685
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
615
686
687
+ if (nodeIdCache.nonEmpty) {
688
+ val nodeIndexUpdater = NodeIndexUpdater (
689
+ split = split,
690
+ nodeIndex = nodeIndex)
691
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
692
+ }
693
+
616
694
// enqueue left child and right child if they are not leaves
617
695
if (! leftChildIsLeaf) {
618
696
nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
629
707
}
630
708
}
631
709
710
+ if (nodeIdCache.nonEmpty) {
711
+ // Update the cache if needed.
712
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
713
+ }
632
714
}
633
715
634
716
/**
0 commit comments