
Commit ac0b9f8

Small updates based on code review.
Main change: Now using << instead of math.pow.
1 parent db0d773 commit ac0b9f8
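
The commit rests on the identity that, for non-negative Int exponents, a left shift computes the same power of two as math.pow without a Double round trip: math.pow(2, n).toInt == 1 << n, so 2^(n+1) - 1 becomes (2 << n) - 1 and 2^n - 1 becomes (1 << n) - 1. A minimal standalone sketch (not part of the commit) checking that equivalence, including the Scala precedence detail that + and - bind tighter than <<, so 1 << k - 1 parses as 1 << (k - 1):

// Sketch only: verifies the math.pow-to-shift rewrite used throughout this commit.
object ShiftVsPow {
  def main(args: Array[String]): Unit = {
    for (n <- 0 to 20) {
      val viaPow = math.pow(2, n).toInt - 1   // old form: Double arithmetic, then truncation
      val viaShift = (1 << n) - 1             // new form: pure Int arithmetic
      assert(viaPow == viaShift, s"mismatch at n = $n")
    }
    // Scala precedence: + and - bind tighter than <<, so no parentheses are
    // needed around the exponent in expressions like (1 << k - 1) - 1.
    val k = 6
    assert((1 << k - 1) == (1 << (k - 1)))
    println("math.pow(2, n).toInt - 1 == (1 << n) - 1 for n in 0..20")
  }
}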

File tree

3 files changed: +33 −42 lines

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 32 additions & 40 deletions
@@ -82,13 +82,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
-    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
+    val maxNumNodes = (2 << maxDepth) - 1
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
-    val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later?
-    nodesInTree(0) = true

     // Calculate level for single group construction
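
For readers new to this code: the bound above comes from storing the tree as a complete binary tree in a flat array. A tree of depth maxDepth has at most 2^(maxDepth + 1) - 1 = (2 << maxDepth) - 1 nodes, and the nodes of a given level occupy a contiguous block starting at index (1 << level) - 1. A small illustrative sketch (not from the patch):

// Sketch: flat-array layout assumed by maxNumNodes and the level offsets below.
val maxDepth = 3
val maxNumNodes = (2 << maxDepth) - 1        // 2^(maxDepth + 1) - 1 = 15
for (level <- 0 to maxDepth) {
  val offset = (1 << level) - 1              // index of the first node at this level
  val count = 1 << level                     // number of nodes at this level
  println(s"level $level: indices $offset to ${offset + count - 1}")
}
// Prints:
//   level 0: indices 0 to 0
//   level 1: indices 1 to 2
//   level 2: indices 3 to 6
//   level 3: indices 7 to 14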

@@ -129,7 +127,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
         metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")

-      val levelNodeIndexOffset = math.pow(2, level).toInt - 1
+      val levelNodeIndexOffset = (1 << level) - 1
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         val nodeIndex = levelNodeIndexOffset + index
         val isLeftChild = level != 0 && nodeIndex % 2 == 1
@@ -138,8 +136,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
         } else {
           (nodeIndex - 2) / 2
         }
-        // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
-        // TODO: Use above check to skip unused branch of tree
         // Extract info for this node (index) at the current level.
         timer.start("extractNodeInfo")
         extractNodeInfo(nodeSplitStats, level, index, nodes)
@@ -158,7 +154,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
         timer.stop("extractInfoForLowerLevels")
         logDebug("final best split = " + nodeSplitStats._1)
       }
-      require(math.pow(2, level) == splitsStatsForLevel.length)
+      require((1 << level) == splitsStatsForLevel.length)
       // Check whether all the nodes at the current level at leaves.
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
@@ -196,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
       nodes: Array[Node]): Unit = {
     val split = nodeSplitStats._1
     val stats = nodeSplitStats._2
-    val nodeIndex = math.pow(2, level).toInt - 1 + index
+    val nodeIndex = (1 << level) - 1 + index
     val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
     val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
     logDebug("Node = " + node)
@@ -212,24 +208,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
       maxDepth: Int,
       nodeSplitStats: (Split, InformationGainStats),
       parentImpurities: Array[Double]): Unit = {
+
     if (level >= maxDepth) {
       return
     }
-    // 0 corresponds to the left child node and 1 corresponds to the right child node.
-    var i = 0
-    while (i <= 1) {
-      // Calculate the index of the node from the node level and the index at the current level.
-      val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
-      val impurity = if (i == 0) {
-        nodeSplitStats._2.leftImpurity
-      } else {
-        nodeSplitStats._2.rightImpurity
-      }
-      logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
-      // noting the parent impurities
-      parentImpurities(nodeIndex) = impurity
-      i += 1
-    }
+
+    val leftNodeIndex = (2 << level) - 1 + 2 * index
+    val leftImpurity = nodeSplitStats._2.leftImpurity
+    logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
+    parentImpurities(leftNodeIndex) = leftImpurity
+
+    val rightNodeIndex = leftNodeIndex + 1
+    val rightImpurity = nodeSplitStats._2.rightImpurity
+    logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
+    parentImpurities(rightNodeIndex) = rightImpurity
   }
 }
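
The child indices in the rewritten extractInfoForLowerLevels follow from the same flat layout: the node at position index on a level has its left child at (2^(level + 1) - 1) + 2 * index, which is exactly (2 << level) - 1 + 2 * index, and its right child in the next slot. For example, with level = 2 and index = 1 the children land at indices 9 and 10, the second pair of slots on level 3. This matches the loop it replaces, where i = 0 and i = 1 picked out the left and right impurities.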

@@ -464,7 +456,7 @@ object DecisionTree extends Serializable with Logging {
       // the nodes are divided into multiple groups at each level with the number of groups
       // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
+      val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
       var bestSplits = new Array[(Split, InformationGainStats)](0)
       // Iterate over each group of nodes at a level.
@@ -534,7 +526,7 @@ object DecisionTree extends Serializable with Logging {

     // numNodes: Number of nodes in this (level of tree, group),
     // where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = math.pow(2, level).toInt / numGroups
+    val numNodes = (1 << level) / numGroups
     logDebug("numNodes = " + numNodes)

     // Find the number of features by looking at the first sample.
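
To make the grouping arithmetic concrete: with maxLevelForSingleGroup = 10, level 13 holds 2^13 = 8192 nodes, numGroups = 1 << (13 - 10) = 8, and each group therefore covers numNodes = (1 << 13) / 8 = 1024 nodes. (As noted above, - binds tighter than << in Scala, so 1 << level - maxLevelForSingleGroup needs no extra parentheses.)
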
@@ -563,24 +555,24 @@ object DecisionTree extends Serializable with Logging {
     * @return  Leaf index if the data point reaches a leaf.
     *          Otherwise, last node reachable in tree matching this example.
     */
-    def predictNodeIndex(node: Node, features: Array[Int]): Int = {
+    def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
      if (node.isLeaf) {
        node.id
      } else {
        val featureIndex = node.split.get.feature
        val splitLeft = node.split.get.featureType match {
          case Continuous => {
-            val binIndex = features(featureIndex)
+            val binIndex = binnedFeatures(featureIndex)
            val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
            // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
            // We do not need to check lowSplit since bins are separated by splits.
            featureValueUpperBound <= node.split.get.threshold
          }
          case Categorical => {
            val featureValue = if (metadata.isUnordered(featureIndex)) {
-              features(featureIndex)
+              binnedFeatures(featureIndex)
            } else {
-              val binIndex = features(featureIndex)
+              val binIndex = binnedFeatures(featureIndex)
              bins(featureIndex)(binIndex).category
            }
            node.split.get.categories.contains(featureValue)
@@ -596,9 +588,9 @@ object DecisionTree extends Serializable with Logging {
          }
        } else {
          if (splitLeft) {
-            predictNodeIndex(node.leftNode.get, features)
+            predictNodeIndex(node.leftNode.get, binnedFeatures)
          } else {
-            predictNodeIndex(node.rightNode.get, features)
+            predictNodeIndex(node.rightNode.get, binnedFeatures)
          }
        }
      }
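
The parameter rename makes explicit that predictNodeIndex walks the tree on bin indices rather than raw feature values. For a continuous split it compares the upper threshold of the example's bin with the split threshold, which is enough because bins are delimited exactly at candidate split thresholds. A simplified, self-contained sketch of that test (hypothetical names, not the Spark classes):

// Sketch: route a binned continuous feature left or right.
// binUpperBounds(b) is the upper threshold of bin b; bins partition the feature
// range and are separated exactly at the candidate split thresholds.
def goesLeft(binIndex: Int, splitThreshold: Double, binUpperBounds: Array[Double]): Boolean =
  binUpperBounds(binIndex) <= splitThreshold

// Example: bins (-inf, 1], (1, 2], (2, +inf) and a split at threshold 2.0.
val upperBounds = Array(1.0, 2.0, Double.PositiveInfinity)
assert(goesLeft(binIndex = 1, splitThreshold = 2.0, binUpperBounds = upperBounds))   // (1, 2] stays left
assert(!goesLeft(binIndex = 2, splitThreshold = 2.0, binUpperBounds = upperBounds))  // (2, +inf) goes right
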
@@ -613,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
     }

     // Used for treePointToNodeIndex
-    val levelOffset = (math.pow(2, level) - 1).toInt
+    val levelOffset = (1 << level) - 1

     /**
      * Find the node (indexed from 0 at the start of this level) for the given example.
@@ -678,7 +670,7 @@ object DecisionTree extends Serializable with Logging {
            treePoint.label.toInt
          // Find all matching bins and increment their values
          val featureCategories = metadata.featureArity(featureIndex)
-          val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+          val numCategoricalBins = (1 << featureCategories - 1) - 1
          var binIndex = 0
          while (binIndex < numCategoricalBins) {
            val aggIndex = aggShift + binIndex * numClasses
@@ -764,9 +756,9 @@ object DecisionTree extends Serializable with Logging {
          3 * numBins * numFeatures * nodeIndex +
          3 * numBins * featureIndex +
          3 * binIndex
-        agg(aggIndex) = agg(aggIndex) + 1
-        agg(aggIndex + 1) = agg(aggIndex + 1) + label
-        agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
+        agg(aggIndex) += 1
+        agg(aggIndex + 1) += label
+        agg(aggIndex + 2) += label * label
        featureIndex += 1
      }
    }
@@ -1165,7 +1157,7 @@ object DecisionTree extends Serializable with Logging {
        // Categorical feature
        val featureCategories = metadata.featureArity(featureIndex)
        if (metadata.isUnordered(featureIndex)) {
-          math.pow(2.0, featureCategories - 1).toInt - 1
+          (1 << featureCategories - 1) - 1
        } else {
          featureCategories
        }
@@ -1257,7 +1249,7 @@ object DecisionTree extends Serializable with Logging {
      // Iterating over all nodes at this level
      var node = 0
      while (node < numNodes) {
-        val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
+        val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
        val binsForNode: Array[Double] = getBinDataForNode(node)
        logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
        val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1302,7 +1294,7 @@ object DecisionTree extends Serializable with Logging {
     *      For multiclass classification with a low-arity feature
     *      (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
     *      the feature is split based on subsets of categories.
-    *      There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+    *      There are (1 << maxFeatureValue - 1) - 1 splits.
     *   (b) "ordered features"
     *      For regression and binary classification,
     *      and for multiclass classification with a high-arity feature,
@@ -1391,7 +1383,7 @@ object DecisionTree extends Serializable with Logging {
      if (metadata.isUnordered(featureIndex)) {
        // 2^(maxFeatureValue- 1) - 1 combinations
        var index = 0
-        while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+        while (index < (1 << featureCategories - 1) - 1) {
          val categories: List[Double]
            = extractMultiClassCategories(index + 1, featureCategories)
          splits(featureIndex)(index)
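
Several of the rewritten expressions above count the candidate splits of an unordered categorical feature: with k categories there are 2^(k-1) - 1 distinct ways to send a non-empty, non-complementary subset of categories to the left child, which is what (1 << featureCategories - 1) - 1 computes. A simplified sketch of the enumeration (in the spirit of extractMultiClassCategories, but not the actual helper):

// Sketch: enumerate the 2^(k-1) - 1 candidate category subsets for an unordered
// categorical feature with k categories. Restricting masks to the first k - 1
// categories avoids counting a split and its complement twice, and skipping the
// zero mask drops the trivial "nothing goes left" split.
def candidateSubsets(k: Int): Seq[List[Int]] = {
  val numSplits = (1 << k - 1) - 1   // parses as (1 << (k - 1)) - 1
  (1 to numSplits).map { mask =>
    (0 until k).filter(c => ((mask >> c) & 1) == 1).toList
  }
}

// k = 3 categories {0, 1, 2} gives 2^2 - 1 = 3 splits: {0}, {1}, {0, 1}.
assert(candidateSubsets(3) == Seq(List(0), List(1), List(0, 1)))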

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTMetadata.scala

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ private[tree] object DTMetadata {
    val unorderedFeatures = new mutable.HashSet[Int]()
    if (numClasses > 2) {
      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        val numUnorderedBins = math.pow(2, k - 1) - 1
+        val numUnorderedBins = (1 << k - 1) - 1
        if (numUnorderedBins < maxBins) {
          unorderedFeatures.add(f)
        } else {
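
The guard above decides which categorical features get the unordered treatment in multiclass problems (numClasses > 2): a feature with arity k needs 2^(k-1) - 1 unordered bins, so with maxBins = 32, for example, only features with k <= 6 qualify (2^5 - 1 = 31 < 32), while k = 7 would need 63 bins and falls back to the ordered treatment.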

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.tree.impl

 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD
