@@ -82,13 +82,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8282 // depth of the decision tree
8383 val maxDepth = strategy.maxDepth
8484 // the max number of nodes possible given the depth of the tree
85- val maxNumNodes = math.pow( 2 , maxDepth + 1 ).toInt - 1
85+ val maxNumNodes = ( 2 << maxDepth) - 1
8686 // Initialize an array to hold parent impurity calculations for each node.
8787 val parentImpurities = new Array [Double ](maxNumNodes)
8888 // dummy value for top node (updated during first split calculation)
8989 val nodes = new Array [Node ](maxNumNodes)
90- val nodesInTree = Array .fill[Boolean ](maxNumNodes)(false ) // put into nodes array later?
91- nodesInTree(0 ) = true
9290
9391 // Calculate level for single group construction
9492
@@ -129,7 +127,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
129127 metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
130128 timer.stop(" findBestSplits" )
131129
132- val levelNodeIndexOffset = math.pow( 2 , level).toInt - 1
130+ val levelNodeIndexOffset = ( 1 << level) - 1
133131 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
134132 val nodeIndex = levelNodeIndexOffset + index
135133 val isLeftChild = level != 0 && nodeIndex % 2 == 1
@@ -138,8 +136,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
138136 } else {
139137 (nodeIndex - 2 ) / 2
140138 }
141- // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
142- // TODO: Use above check to skip unused branch of tree
143139 // Extract info for this node (index) at the current level.
144140 timer.start(" extractNodeInfo" )
145141 extractNodeInfo(nodeSplitStats, level, index, nodes)
@@ -158,7 +154,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
158154 timer.stop(" extractInfoForLowerLevels" )
159155 logDebug(" final best split = " + nodeSplitStats._1)
160156 }
161- require(math.pow( 2 , level) == splitsStatsForLevel.length)
157+ require(( 1 << level) == splitsStatsForLevel.length)
162158 // Check whether all the nodes at the current level at leaves.
163159 val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
164160 logDebug(" all leaf = " + allLeaf)
@@ -196,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
196192 nodes : Array [Node ]): Unit = {
197193 val split = nodeSplitStats._1
198194 val stats = nodeSplitStats._2
199- val nodeIndex = math.pow( 2 , level).toInt - 1 + index
195+ val nodeIndex = ( 1 << level) - 1 + index
200196 val isLeaf = (stats.gain <= 0 ) || (level == strategy.maxDepth)
201197 val node = new Node (nodeIndex, stats.predict, isLeaf, Some (split), None , None , Some (stats))
202198 logDebug(" Node = " + node)
@@ -212,24 +208,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
212208 maxDepth : Int ,
213209 nodeSplitStats : (Split , InformationGainStats ),
214210 parentImpurities : Array [Double ]): Unit = {
211+
215212 if (level >= maxDepth) {
216213 return
217214 }
218- // 0 corresponds to the left child node and 1 corresponds to the right child node.
219- var i = 0
220- while (i <= 1 ) {
221- // Calculate the index of the node from the node level and the index at the current level.
222- val nodeIndex = math.pow(2 , level + 1 ).toInt - 1 + 2 * index + i
223- val impurity = if (i == 0 ) {
224- nodeSplitStats._2.leftImpurity
225- } else {
226- nodeSplitStats._2.rightImpurity
227- }
228- logDebug(" nodeIndex = " + nodeIndex + " , impurity = " + impurity)
229- // noting the parent impurities
230- parentImpurities(nodeIndex) = impurity
231- i += 1
232- }
215+
216+ val leftNodeIndex = (2 << level) - 1 + 2 * index
217+ val leftImpurity = nodeSplitStats._2.leftImpurity
218+ logDebug(" leftNodeIndex = " + leftNodeIndex + " , impurity = " + leftImpurity)
219+ parentImpurities(leftNodeIndex) = leftImpurity
220+
221+ val rightNodeIndex = leftNodeIndex + 1
222+ val rightImpurity = nodeSplitStats._2.rightImpurity
223+ logDebug(" rightNodeIndex = " + rightNodeIndex + " , impurity = " + rightImpurity)
224+ parentImpurities(rightNodeIndex) = rightImpurity
233225 }
234226}
235227
@@ -464,7 +456,7 @@ object DecisionTree extends Serializable with Logging {
464456 // the nodes are divided into multiple groups at each level with the number of groups
465457 // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
466458 // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
467- val numGroups = math.pow( 2 , level - maxLevelForSingleGroup).toInt
459+ val numGroups = 1 << level - maxLevelForSingleGroup
468460 logDebug(" numGroups = " + numGroups)
469461 var bestSplits = new Array [(Split , InformationGainStats )](0 )
470462 // Iterate over each group of nodes at a level.
@@ -534,7 +526,7 @@ object DecisionTree extends Serializable with Logging {
534526
535527 // numNodes: Number of nodes in this (level of tree, group),
536528 // where nodes at deeper (larger) levels may be divided into groups.
537- val numNodes = math.pow( 2 , level).toInt / numGroups
529+ val numNodes = ( 1 << level) / numGroups
538530 logDebug(" numNodes = " + numNodes)
539531
540532 // Find the number of features by looking at the first sample.
@@ -563,24 +555,24 @@ object DecisionTree extends Serializable with Logging {
563555 * @return Leaf index if the data point reaches a leaf.
564556 * Otherwise, last node reachable in tree matching this example.
565557 */
566- def predictNodeIndex (node : Node , features : Array [Int ]): Int = {
558+ def predictNodeIndex (node : Node , binnedFeatures : Array [Int ]): Int = {
567559 if (node.isLeaf) {
568560 node.id
569561 } else {
570562 val featureIndex = node.split.get.feature
571563 val splitLeft = node.split.get.featureType match {
572564 case Continuous => {
573- val binIndex = features (featureIndex)
565+ val binIndex = binnedFeatures (featureIndex)
574566 val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
575567 // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
576568 // We do not need to check lowSplit since bins are separated by splits.
577569 featureValueUpperBound <= node.split.get.threshold
578570 }
579571 case Categorical => {
580572 val featureValue = if (metadata.isUnordered(featureIndex)) {
581- features (featureIndex)
573+ binnedFeatures (featureIndex)
582574 } else {
583- val binIndex = features (featureIndex)
575+ val binIndex = binnedFeatures (featureIndex)
584576 bins(featureIndex)(binIndex).category
585577 }
586578 node.split.get.categories.contains(featureValue)
@@ -596,9 +588,9 @@ object DecisionTree extends Serializable with Logging {
596588 }
597589 } else {
598590 if (splitLeft) {
599- predictNodeIndex(node.leftNode.get, features )
591+ predictNodeIndex(node.leftNode.get, binnedFeatures )
600592 } else {
601- predictNodeIndex(node.rightNode.get, features )
593+ predictNodeIndex(node.rightNode.get, binnedFeatures )
602594 }
603595 }
604596 }
@@ -613,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
613605 }
614606
615607 // Used for treePointToNodeIndex
616- val levelOffset = (math.pow( 2 , level) - 1 ).toInt
608+ val levelOffset = (1 << level) - 1
617609
618610 /**
619611 * Find the node (indexed from 0 at the start of this level) for the given example.
@@ -678,7 +670,7 @@ object DecisionTree extends Serializable with Logging {
678670 treePoint.label.toInt
679671 // Find all matching bins and increment their values
680672 val featureCategories = metadata.featureArity(featureIndex)
681- val numCategoricalBins = math.pow( 2.0 , featureCategories - 1 ).toInt - 1
673+ val numCategoricalBins = ( 1 << featureCategories - 1 ) - 1
682674 var binIndex = 0
683675 while (binIndex < numCategoricalBins) {
684676 val aggIndex = aggShift + binIndex * numClasses
@@ -764,9 +756,9 @@ object DecisionTree extends Serializable with Logging {
764756 3 * numBins * numFeatures * nodeIndex +
765757 3 * numBins * featureIndex +
766758 3 * binIndex
767- agg(aggIndex) = agg(aggIndex) + 1
768- agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + label
769- agg(aggIndex + 2 ) = agg(aggIndex + 2 ) + label * label
759+ agg(aggIndex) += 1
760+ agg(aggIndex + 1 ) += label
761+ agg(aggIndex + 2 ) += label * label
770762 featureIndex += 1
771763 }
772764 }
@@ -1165,7 +1157,7 @@ object DecisionTree extends Serializable with Logging {
11651157 // Categorical feature
11661158 val featureCategories = metadata.featureArity(featureIndex)
11671159 if (metadata.isUnordered(featureIndex)) {
1168- math.pow( 2.0 , featureCategories - 1 ).toInt - 1
1160+ ( 1 << featureCategories - 1 ) - 1
11691161 } else {
11701162 featureCategories
11711163 }
@@ -1257,7 +1249,7 @@ object DecisionTree extends Serializable with Logging {
12571249 // Iterating over all nodes at this level
12581250 var node = 0
12591251 while (node < numNodes) {
1260- val nodeImpurityIndex = math.pow( 2 , level).toInt - 1 + node + groupShift
1252+ val nodeImpurityIndex = ( 1 << level) - 1 + node + groupShift
12611253 val binsForNode : Array [Double ] = getBinDataForNode(node)
12621254 logDebug(" nodeImpurityIndex = " + nodeImpurityIndex)
12631255 val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1302,7 +1294,7 @@ object DecisionTree extends Serializable with Logging {
13021294 * For multiclass classification with a low-arity feature
13031295 * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
13041296 * the feature is split based on subsets of categories.
1305- * There are math.pow(2, maxFeatureValue - 1) - 1 splits.
1297+ * There are (1 << maxFeatureValue - 1) - 1 splits.
13061298 * (b) "ordered features"
13071299 * For regression and binary classification,
13081300 * and for multiclass classification with a high-arity feature,
@@ -1391,7 +1383,7 @@ object DecisionTree extends Serializable with Logging {
13911383 if (metadata.isUnordered(featureIndex)) {
13921384 // 2^(maxFeatureValue- 1) - 1 combinations
13931385 var index = 0
1394- while (index < math.pow( 2.0 , featureCategories - 1 ).toInt - 1 ) {
1386+ while (index < ( 1 << featureCategories - 1 ) - 1 ) {
13951387 val categories : List [Double ]
13961388 = extractMultiClassCategories(index + 1 , featureCategories)
13971389 splits(featureIndex)(index)
0 commit comments