@@ -453,6 +453,129 @@ object DecisionTree extends Serializable with Logging {
453
453
}
454
454
}
455
455
456
/**
 * Get the node index corresponding to this data point.
 * This function mimics prediction, passing an example from the given node down the tree
 * until it reaches a leaf or a node that is missing at least one child.
 *
 * @param node  Node to start the descent from.
 * @param binnedFeatures  Binned value of each feature of the data point, indexed by feature.
 *                        It is used as a bin index into `bins`, except for unordered
 *                        categorical features, where it is used directly as the feature value.
 * @param bins  Bins, indexed as bins(featureIndex)(binIndex).
 * @param unorderedFeatures  Indices of the categorical features handled as unordered.
 * @return  The id of the leaf if the data point reaches a leaf.
 *          Otherwise, the index of the left or right child (as computed by
 *          Node.leftChildIndex / Node.rightChildIndex) of the last node reachable
 *          in the tree matching this example.
 *          Note: This is the global node index, i.e., the index used in the tree.
 *          This index is different from the index used during training a particular
 *          set of nodes in a (level, group).
 * @throws NoSuchElementException if a non-leaf node has no split.
 * @throws RuntimeException if the split's feature type is neither Continuous nor Categorical.
 */
def predictNodeIndex(
    node: Node,
    binnedFeatures: Array[Int],
    bins: Array[Array[Bin]],
    unorderedFeatures: Set[Int]): Int = {
  if (node.isLeaf) {
    node.id
  } else {
    // Unwrap the split once; a missing split on a non-leaf node is reported with the node id
    // instead of a bare "None.get".
    val split = node.split.getOrElse(
      throw new NoSuchElementException(
        s"predictNodeIndex failed: non-leaf node ${node.id} has no split."))
    val featureIndex = split.feature
    val splitLeft = split.featureType match {
      case Continuous =>
        val binIndex = binnedFeatures(featureIndex)
        // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
        // We do not need to check lowSplit since bins are separated by splits.
        val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
        featureValueUpperBound <= split.threshold
      case Categorical =>
        val featureValue = if (unorderedFeatures.contains(featureIndex)) {
          // Unordered feature: the binned value is used as the feature value itself.
          binnedFeatures(featureIndex)
        } else {
          // Ordered feature: look up the category stored on the bin.
          val binIndex = binnedFeatures(featureIndex)
          bins(featureIndex)(binIndex).category
        }
        split.categories.contains(featureValue)
      case other =>
        throw new RuntimeException(
          s"predictNodeIndex failed: unsupported feature type $other " +
            s"for feature $featureIndex at node ${node.id}.")
    }
    if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
      // Return index from next layer of nodes to train
      if (splitLeft) {
        Node.leftChildIndex(node.id)
      } else {
        Node.rightChildIndex(node.id)
      }
    } else {
      // Both children exist: keep descending on the side chosen by the split.
      if (splitLeft) {
        predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
      } else {
        predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
      }
    }
  }
}
507
+
508
/**
 * Helper for binSeqOp, used when at least one feature is unordered.
 *
 * Adds the label of one data point to the aggregators of a single node:
 *  - For an ordered feature, only the aggregator of the point's bin is updated.
 *  - For an unordered feature with n = bins(featureIndex).size bins, every bin b is visited.
 *    Aggregator b is updated when the categories of bin b's high split contain the
 *    feature value; aggregator n + b is updated otherwise.
 *
 * @param agg  Aggregators, indexed as agg(nodeIndex)(featureIndex)(binIndex). Updated in place.
 * @param treePoint  Data point being aggregated.
 * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
 * @param bins  Bins, indexed as bins(featureIndex)(binIndex).
 * @param unorderedFeatures  Indices of the features handled as unordered.
 */
def someUnorderedBinSeqOp(
    agg: Array[Array[Array[ImpurityAggregator]]],
    treePoint: TreePoint,
    nodeIndex: Int,
    bins: Array[Array[Bin]],
    unorderedFeatures: Set[Int]): Unit = {
  val featureCount = treePoint.binnedFeatures.size
  var f = 0
  while (f < featureCount) {
    if (!unorderedFeatures.contains(f)) {
      // Ordered feature: exactly one aggregator receives the label.
      val bin = treePoint.binnedFeatures(f)
      agg(nodeIndex)(f)(bin).add(treePoint.label)
    } else {
      // Unordered feature: the binned value is the feature value itself.
      val category = treePoint.binnedFeatures(f)
      val featureBins = bins(f)
      // NOTE(review): the original carried a trailing `metadata.numBins(featureIndex)` remark
      // on this value -- confirm that bins(featureIndex).size always matches it.
      val binCount = featureBins.size
      var b = 0
      while (b < binCount) {
        // The first binCount slots collect points matched by the bin's high split,
        // the following binCount slots collect the points it does not match.
        val matched = featureBins(b).highSplit.categories.contains(category)
        val slot = if (matched) b else binCount + b
        agg(nodeIndex)(f)(slot).add(treePoint.label)
        b += 1
      }
    }
    f += 1
  }
}
550
+
551
/**
 * Helper for binSeqOp: for regression and for classification with only ordered features.
 *
 * For every feature of the data point, the label is added to one aggregator of the
 * given node: the one selected by the point's binned value for that feature.
 *
 * @param agg  Aggregators, indexed as agg(nodeIndex)(featureIndex)(binIndex).
 *             Updated in place by this function; nothing is returned.
 * @param treePoint  Data point being aggregated.
 * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
 */
def orderedBinSeqOp(
    agg: Array[Array[Array[ImpurityAggregator]]],
    treePoint: TreePoint,
    nodeIndex: Int): Unit = {
  val pointLabel = treePoint.label
  val binned = treePoint.binnedFeatures
  val featureCount = binned.size
  // Plain index loop over the features; one aggregator update per feature.
  var f = 0
  while (f < featureCount) {
    val bin = binned(f)
    agg(nodeIndex)(f)(bin).add(pointLabel)
    f += 1
  }
}
578
+
456
579
/**
457
580
* Returns an array of optimal splits for a group of nodes at a given level
458
581
*
@@ -529,60 +652,8 @@ object DecisionTree extends Serializable with Logging {
529
652
// shift when more than one group is used at deep tree level
530
653
val groupShift = numNodes * groupIndex
531
654
532
- /**
533
- * Get the node index corresponding to this data point.
534
- * This function mimics prediction, passing an example from the root node down to a node
535
- * at the current level being trained; that node's index is returned.
536
- *
537
- * @return Leaf index if the data point reaches a leaf.
538
- * Otherwise, last node reachable in tree matching this example.
539
- * Note: This is the global node index, i.e., the index used in the tree.
540
- * This index is different from the index used during training a particular
541
- * set of nodes in a (level, group).
542
- */
543
- def predictNodeIndex (node : Node , binnedFeatures : Array [Int ]): Int = {
544
- if (node.isLeaf) {
545
- node.id
546
- } else {
547
- val featureIndex = node.split.get.feature
548
- val splitLeft = node.split.get.featureType match {
549
- case Continuous => {
550
- val binIndex = binnedFeatures(featureIndex)
551
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
552
- // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
553
- // We do not need to check lowSplit since bins are separated by splits.
554
- featureValueUpperBound <= node.split.get.threshold
555
- }
556
- case Categorical => {
557
- val featureValue = if (metadata.isUnordered(featureIndex)) {
558
- binnedFeatures(featureIndex)
559
- } else {
560
- val binIndex = binnedFeatures(featureIndex)
561
- bins(featureIndex)(binIndex).category
562
- }
563
- node.split.get.categories.contains(featureValue)
564
- }
565
- case _ => throw new RuntimeException (s " predictNodeIndex failed for unknown reason. " )
566
- }
567
- if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
568
- // Return index from next layer of nodes to train
569
- if (splitLeft) {
570
- Node .leftChildIndex(node.id)
571
- } else {
572
- Node .rightChildIndex(node.id)
573
- }
574
- } else {
575
- if (splitLeft) {
576
- predictNodeIndex(node.leftNode.get, binnedFeatures)
577
- } else {
578
- predictNodeIndex(node.rightNode.get, binnedFeatures)
579
- }
580
- }
581
- }
582
- }
583
-
584
655
// Used for treePointToNodeIndex
585
- val levelOffset = Node .maxNodesInSubtree(level - 1 )
656
+ val globalNodeIndexOffset = Node .maxNodesInSubtree(level - 1 ) + groupShift + 1
586
657
587
658
/**
588
659
* Find the node index for the given example.
@@ -593,90 +664,12 @@ object DecisionTree extends Serializable with Logging {
593
664
if (level == 0 ) {
594
665
0
595
666
} else {
596
- val globalNodeIndex = predictNodeIndex(nodes(1 ), treePoint.binnedFeatures)
667
+ val globalNodeIndex = predictNodeIndex(nodes(1 ), treePoint.binnedFeatures, bins, metadata.unorderedFeatures )
597
668
// Get index for this (level, group).
598
669
// - maxNodesInSubtree(level - 1), inside globalNodeIndexOffset, corrects for nodes before this level.
599
670
// - groupShift corrects for groups in this level before the current group.
600
671
// - 1 corrects for the fact that globalNodeIndex starts at 1, not 0.
601
- globalNodeIndex - levelOffset - groupShift - 1
602
- }
603
- }
604
-
605
-
606
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
607
-
608
- /**
609
- * Helper for binSeqOp.
610
- *
611
- * @param agg Array storing aggregate calculation.
612
- * For ordered features, this is of size:
613
- * numClasses * numBins * numFeatures * numNodes.
614
- * For unordered features, this is of size:
615
- * 2 * numClasses * numBins * numFeatures * numNodes.
616
- * @param treePoint Data point being aggregated.
617
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
618
- */
619
- def someUnorderedBinSeqOp (
620
- agg : Array [Array [Array [ImpurityAggregator ]]],
621
- treePoint : TreePoint ,
622
- nodeIndex : Int ): Unit = {
623
- val label = treePoint.label
624
- // Iterate over all features.
625
- var featureIndex = 0
626
- while (featureIndex < numFeatures) {
627
- if (metadata.isUnordered(featureIndex)) {
628
- // Unordered feature
629
- val featureValue = treePoint.binnedFeatures(featureIndex)
630
- // Update the left or right count for one bin.
631
- // Find all matching bins and increment their values.
632
- val numCategoricalBins = metadata.numBins(featureIndex)
633
- var binIndex = 0
634
- while (binIndex < numCategoricalBins) {
635
- if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
636
- agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
637
- } else {
638
- agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label)
639
- }
640
- binIndex += 1
641
- }
642
- } else {
643
- // Ordered feature
644
- val binIndex = treePoint.binnedFeatures(featureIndex)
645
- agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
646
- }
647
- featureIndex += 1
648
- }
649
- }
650
-
651
- /**
652
- * Helper for binSeqOp: for regression and for classification with only ordered features.
653
- *
654
- * Performs a sequential aggregation over a partition for regression.
655
- * For l nodes, k features,
656
- * the count, sum, sum of squares of one of the p bins is incremented.
657
- *
658
- * @param agg Array storing aggregate calculation, updated by this function.
659
- * Size: 3 * numBins * numFeatures * numNodes
660
- * @param treePoint Data point being aggregated.
661
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
662
- * @return agg
663
- */
664
- def orderedBinSeqOp (
665
- agg : Array [Array [Array [ImpurityAggregator ]]],
666
- treePoint : TreePoint ,
667
- nodeIndex : Int ): Unit = {
668
- val label = treePoint.label
669
- // Iterate over all features.
670
- var featureIndex = 0
671
- while (featureIndex < numFeatures) {
672
- // Update count, sum, and sum^2 for one bin.
673
- val binIndex = treePoint.binnedFeatures(featureIndex)
674
- if (binIndex >= agg(nodeIndex)(featureIndex).size) {
675
- throw new RuntimeException (
676
- s " binIndex: $binIndex, agg(nodeIndex)(featureIndex).size = ${agg(nodeIndex)(featureIndex).size}" )
677
- }
678
- agg(nodeIndex)(featureIndex)(binIndex).add(label)
679
- featureIndex += 1
672
+ globalNodeIndex - globalNodeIndexOffset
680
673
}
681
674
}
682
675
@@ -709,7 +702,7 @@ object DecisionTree extends Serializable with Logging {
709
702
if (metadata.unorderedFeatures.isEmpty) {
710
703
orderedBinSeqOp(agg, treePoint, nodeIndex)
711
704
} else {
712
- someUnorderedBinSeqOp(agg, treePoint, nodeIndex)
705
+ someUnorderedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures )
713
706
}
714
707
}
715
708
agg
@@ -751,7 +744,6 @@ object DecisionTree extends Serializable with Logging {
751
744
// Calculate best splits for all nodes at a given level
752
745
timer.start(" chooseSplits" )
753
746
val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
754
- val globalNodeIndexOffset = Node .maxNodesInSubtree(level - 1 ) + groupShift + 1
755
747
// Iterating over all nodes at this level
756
748
var nodeIndex = 0
757
749
while (nodeIndex < numNodes) {
0 commit comments