Skip to content

Commit fd8df30

Browse files
committed
Moved some aggregation helpers outside of findBestSplitsPerGroup
1 parent d7c53ee commit fd8df30

File tree

1 file changed

+127
-135
lines changed

1 file changed

+127
-135
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 127 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,129 @@ object DecisionTree extends Serializable with Logging {
453453
}
454454
}
455455

456+
/**
457+
* Get the node index corresponding to this data point.
458+
* This function mimics prediction, passing an example from the root node down to a node
459+
* at the current level being trained; that node's index is returned.
460+
*
461+
* @return Leaf index if the data point reaches a leaf.
462+
* Otherwise, last node reachable in tree matching this example.
463+
* Note: This is the global node index, i.e., the index used in the tree.
464+
* This index is different from the index used during training a particular
465+
* set of nodes in a (level, group).
466+
*/
467+
def predictNodeIndex(node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = {
468+
if (node.isLeaf) {
469+
node.id
470+
} else {
471+
val featureIndex = node.split.get.feature
472+
val splitLeft = node.split.get.featureType match {
473+
case Continuous => {
474+
val binIndex = binnedFeatures(featureIndex)
475+
val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
476+
// bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
477+
// We do not need to check lowSplit since bins are separated by splits.
478+
featureValueUpperBound <= node.split.get.threshold
479+
}
480+
case Categorical => {
481+
val featureValue = if (unorderedFeatures.contains(featureIndex)) {
482+
binnedFeatures(featureIndex)
483+
} else {
484+
val binIndex = binnedFeatures(featureIndex)
485+
bins(featureIndex)(binIndex).category
486+
}
487+
node.split.get.categories.contains(featureValue)
488+
}
489+
case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
490+
}
491+
if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
492+
// Return index from next layer of nodes to train
493+
if (splitLeft) {
494+
Node.leftChildIndex(node.id)
495+
} else {
496+
Node.rightChildIndex(node.id)
497+
}
498+
} else {
499+
if (splitLeft) {
500+
predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
501+
} else {
502+
predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
503+
}
504+
}
505+
}
506+
}
507+
508+
/**
509+
* Helper for binSeqOp.
510+
*
511+
* @param agg Array storing aggregate calculation.
512+
* For ordered features, this is of size:
513+
* numClasses * numBins * numFeatures * numNodes.
514+
* For unordered features, this is of size:
515+
* 2 * numClasses * numBins * numFeatures * numNodes.
516+
* @param treePoint Data point being aggregated.
517+
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
518+
*/
519+
def someUnorderedBinSeqOp(
520+
agg: Array[Array[Array[ImpurityAggregator]]],
521+
treePoint: TreePoint,
522+
nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Unit = {
523+
// Iterate over all features.
524+
val numFeatures = treePoint.binnedFeatures.size
525+
var featureIndex = 0
526+
while (featureIndex < numFeatures) {
527+
if (unorderedFeatures.contains(featureIndex)) {
528+
// Unordered feature
529+
val featureValue = treePoint.binnedFeatures(featureIndex)
530+
// Update the left or right count for one bin.
531+
// Find all matching bins and increment their values.
532+
val numCategoricalBins = bins(featureIndex).size //metadata.numBins(featureIndex)
533+
var binIndex = 0
534+
while (binIndex < numCategoricalBins) {
535+
if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
536+
agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
537+
} else {
538+
agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label)
539+
}
540+
binIndex += 1
541+
}
542+
} else {
543+
// Ordered feature
544+
val binIndex = treePoint.binnedFeatures(featureIndex)
545+
agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
546+
}
547+
featureIndex += 1
548+
}
549+
}
550+
551+
/**
552+
* Helper for binSeqOp: for regression and for classification with only ordered features.
553+
*
554+
* Performs a sequential aggregation over a partition for regression.
555+
* For l nodes, k features,
556+
* the count, sum, sum of squares of one of the p bins is incremented.
557+
*
558+
* @param agg Array storing aggregate calculation, updated by this function.
559+
* Size: 3 * numBins * numFeatures * numNodes
560+
* @param treePoint Data point being aggregated.
561+
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
562+
* @return agg
563+
*/
564+
def orderedBinSeqOp(
565+
agg: Array[Array[Array[ImpurityAggregator]]],
566+
treePoint: TreePoint,
567+
nodeIndex: Int): Unit = {
568+
val label = treePoint.label
569+
// Iterate over all features.
570+
val numFeatures = treePoint.binnedFeatures.size
571+
var featureIndex = 0
572+
while (featureIndex < numFeatures) {
573+
val binIndex = treePoint.binnedFeatures(featureIndex)
574+
agg(nodeIndex)(featureIndex)(binIndex).add(label)
575+
featureIndex += 1
576+
}
577+
}
578+
456579
/**
457580
* Returns an array of optimal splits for a group of nodes at a given level
458581
*
@@ -529,60 +652,8 @@ object DecisionTree extends Serializable with Logging {
529652
// shift when more than one group is used at deep tree level
530653
val groupShift = numNodes * groupIndex
531654

532-
/**
533-
* Get the node index corresponding to this data point.
534-
* This function mimics prediction, passing an example from the root node down to a node
535-
* at the current level being trained; that node's index is returned.
536-
*
537-
* @return Leaf index if the data point reaches a leaf.
538-
* Otherwise, last node reachable in tree matching this example.
539-
* Note: This is the global node index, i.e., the index used in the tree.
540-
* This index is different from the index used during training a particular
541-
* set of nodes in a (level, group).
542-
*/
543-
def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
544-
if (node.isLeaf) {
545-
node.id
546-
} else {
547-
val featureIndex = node.split.get.feature
548-
val splitLeft = node.split.get.featureType match {
549-
case Continuous => {
550-
val binIndex = binnedFeatures(featureIndex)
551-
val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
552-
// bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
553-
// We do not need to check lowSplit since bins are separated by splits.
554-
featureValueUpperBound <= node.split.get.threshold
555-
}
556-
case Categorical => {
557-
val featureValue = if (metadata.isUnordered(featureIndex)) {
558-
binnedFeatures(featureIndex)
559-
} else {
560-
val binIndex = binnedFeatures(featureIndex)
561-
bins(featureIndex)(binIndex).category
562-
}
563-
node.split.get.categories.contains(featureValue)
564-
}
565-
case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
566-
}
567-
if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
568-
// Return index from next layer of nodes to train
569-
if (splitLeft) {
570-
Node.leftChildIndex(node.id)
571-
} else {
572-
Node.rightChildIndex(node.id)
573-
}
574-
} else {
575-
if (splitLeft) {
576-
predictNodeIndex(node.leftNode.get, binnedFeatures)
577-
} else {
578-
predictNodeIndex(node.rightNode.get, binnedFeatures)
579-
}
580-
}
581-
}
582-
}
583-
584655
// Used for treePointToNodeIndex
585-
val levelOffset = Node.maxNodesInSubtree(level - 1)
656+
val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1
586657

587658
/**
588659
* Find the node index for the given example.
@@ -593,90 +664,12 @@ object DecisionTree extends Serializable with Logging {
593664
if (level == 0) {
594665
0
595666
} else {
596-
val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures)
667+
val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
597668
// Get index for this (level, group).
598669
// - levelOffset corrects for nodes before this level.
599670
// - groupShift corrects for groups in this level before the current group.
600671
// - 1 corrects for the fact that globalNodeIndex starts at 1, not 0.
601-
globalNodeIndex - levelOffset - groupShift - 1
602-
}
603-
}
604-
605-
606-
val rightChildShift = numClasses * numBins * numFeatures * numNodes
607-
608-
/**
609-
* Helper for binSeqOp.
610-
*
611-
* @param agg Array storing aggregate calculation.
612-
* For ordered features, this is of size:
613-
* numClasses * numBins * numFeatures * numNodes.
614-
* For unordered features, this is of size:
615-
* 2 * numClasses * numBins * numFeatures * numNodes.
616-
* @param treePoint Data point being aggregated.
617-
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
618-
*/
619-
def someUnorderedBinSeqOp(
620-
agg: Array[Array[Array[ImpurityAggregator]]],
621-
treePoint: TreePoint,
622-
nodeIndex: Int): Unit = {
623-
val label = treePoint.label
624-
// Iterate over all features.
625-
var featureIndex = 0
626-
while (featureIndex < numFeatures) {
627-
if (metadata.isUnordered(featureIndex)) {
628-
// Unordered feature
629-
val featureValue = treePoint.binnedFeatures(featureIndex)
630-
// Update the left or right count for one bin.
631-
// Find all matching bins and increment their values.
632-
val numCategoricalBins = metadata.numBins(featureIndex)
633-
var binIndex = 0
634-
while (binIndex < numCategoricalBins) {
635-
if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
636-
agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
637-
} else {
638-
agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label)
639-
}
640-
binIndex += 1
641-
}
642-
} else {
643-
// Ordered feature
644-
val binIndex = treePoint.binnedFeatures(featureIndex)
645-
agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label)
646-
}
647-
featureIndex += 1
648-
}
649-
}
650-
651-
/**
652-
* Helper for binSeqOp: for regression and for classification with only ordered features.
653-
*
654-
* Performs a sequential aggregation over a partition for regression.
655-
* For l nodes, k features,
656-
* the count, sum, sum of squares of one of the p bins is incremented.
657-
*
658-
* @param agg Array storing aggregate calculation, updated by this function.
659-
* Size: 3 * numBins * numFeatures * numNodes
660-
* @param treePoint Data point being aggregated.
661-
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
662-
* @return agg
663-
*/
664-
def orderedBinSeqOp(
665-
agg: Array[Array[Array[ImpurityAggregator]]],
666-
treePoint: TreePoint,
667-
nodeIndex: Int): Unit = {
668-
val label = treePoint.label
669-
// Iterate over all features.
670-
var featureIndex = 0
671-
while (featureIndex < numFeatures) {
672-
// Update count, sum, and sum^2 for one bin.
673-
val binIndex = treePoint.binnedFeatures(featureIndex)
674-
if (binIndex >= agg(nodeIndex)(featureIndex).size) {
675-
throw new RuntimeException(
676-
s"binIndex: $binIndex, agg(nodeIndex)(featureIndex).size = ${agg(nodeIndex)(featureIndex).size}")
677-
}
678-
agg(nodeIndex)(featureIndex)(binIndex).add(label)
679-
featureIndex += 1
672+
globalNodeIndex - globalNodeIndexOffset
680673
}
681674
}
682675

@@ -709,7 +702,7 @@ object DecisionTree extends Serializable with Logging {
709702
if (metadata.unorderedFeatures.isEmpty) {
710703
orderedBinSeqOp(agg, treePoint, nodeIndex)
711704
} else {
712-
someUnorderedBinSeqOp(agg, treePoint, nodeIndex)
705+
someUnorderedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
713706
}
714707
}
715708
agg
@@ -751,7 +744,6 @@ object DecisionTree extends Serializable with Logging {
751744
// Calculate best splits for all nodes at a given level
752745
timer.start("chooseSplits")
753746
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
754-
val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1
755747
// Iterating over all nodes at this level
756748
var nodeIndex = 0
757749
while (nodeIndex < numNodes) {

0 commit comments

Comments
 (0)