Skip to content

Commit bce835f

Browse files
committed
code cleanup
1 parent 7e5f08c commit bce835f

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
144144
new DecisionTreeModel(topNode, strategy.algo)
145145
}
146146

147+
// TODO: Unit test this
147148
/**
148149
* Extract the decision tree node information for the given tree level and node index
149150
*/
@@ -161,6 +162,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
161162
nodes(nodeIndex) = node
162163
}
163164

165+
// TODO: Unit test this
164166
/**
165167
* Extract the decision tree node information for the children of the node
166168
*/
@@ -458,6 +460,8 @@ object DecisionTree extends Serializable with Logging {
458460
logDebug("numClasses = " + numClasses)
459461
val labelWeights = strategy.labelWeights
460462
logDebug("labelWeights = " + labelWeights)
463+
val isMulticlassClassification = strategy.isMulticlassClassification
464+
logDebug("isMulticlassClassification = " + isMulticlassClassification)
461465

462466

463467
// shift when more than one group is used at deep tree level
@@ -582,7 +586,7 @@ object DecisionTree extends Serializable with Logging {
582586
} else {
583587
// Perform sequential search to find bin for categorical features.
584588
val binIndex = {
585-
if (strategy.isMultiClassification) {
589+
if (isMulticlassClassification) {
586590
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
587591
} else {
588592
sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
@@ -606,7 +610,9 @@ object DecisionTree extends Serializable with Logging {
606610
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
607611
// Calculate bin index and label per feature per node.
608612
val arr = new Array[Double](1 + (numFeatures * numNodes))
613+
// First element of the array is the label of the instance.
609614
arr(0) = labeledPoint.label
615+
// Iterate over nodes.
610616
var nodeIndex = 0
611617
while (nodeIndex < numNodes) {
612618
val parentFilters = findParentFilters(nodeIndex)
@@ -629,7 +635,10 @@ object DecisionTree extends Serializable with Logging {
629635
arr
630636
}
631637

632-
/**
638+
// Find feature bins for all nodes at a level.
639+
val binMappedRDD = input.map(x => findBinsForLevel(x))
640+
641+
/**
633642
* Performs a sequential aggregation over a partition for classification. For l nodes,
634643
* k features, either the left count or the right count of one of the p bins is
635644
* incremented based upon whether the feature is classified as 0 or 1.
@@ -663,7 +672,7 @@ object DecisionTree extends Serializable with Logging {
663672
label.toInt match {
664673
case n: Int =>
665674
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
666-
if (!isFeatureContinuous && strategy.isMultiClassification) {
675+
if (!isFeatureContinuous && isMulticlassClassification) {
667676
// Find all matching bins and increment their values
668677
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
669678
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
@@ -736,7 +745,6 @@ object DecisionTree extends Serializable with Logging {
736745
agg
737746
}
738747

739-
// TODO: Double-check this
740748
// Calculate bin aggregate length for classification or regression.
741749
val binAggregateLength = strategy.algo match {
742750
case Classification => numClasses * numBins * numFeatures * numNodes
@@ -760,9 +768,6 @@ object DecisionTree extends Serializable with Logging {
760768
combinedAggregate
761769
}
762770

763-
// Find feature bins for all nodes at a level.
764-
val binMappedRDD = input.map(x => findBinsForLevel(x))
765-
766771
// Calculate bin aggregates.
767772
val binAggregates = {
768773
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
@@ -922,7 +927,7 @@ object DecisionTree extends Serializable with Logging {
922927
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
923928
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
924929

925-
if (strategy.isMultiClassification) {
930+
if (isMulticlassClassification) {
926931
var featureIndex = 0
927932
while (featureIndex < numFeatures){
928933
var splitIndex = 0
@@ -1096,7 +1101,7 @@ object DecisionTree extends Serializable with Logging {
10961101
numBins - 1
10971102
} else { // Categorical feature
10981103
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1099-
if (strategy.isMultiClassification) {
1104+
if (isMulticlassClassification) {
11001105
math.pow(2.0, featureCategories - 1).toInt - 1
11011106
} else { // Binary classification
11021107
featureCategories
@@ -1177,6 +1182,9 @@ object DecisionTree extends Serializable with Logging {
11771182
val maxBins = strategy.maxBins
11781183
val numBins = if (maxBins <= count) maxBins else count.toInt
11791184
logDebug("numBins = " + numBins)
1185+
val isMulticlassClassification = strategy.isMulticlassClassification
1186+
logDebug("isMulticlassClassification = " + isMulticlassClassification)
1187+
11801188

11811189
/*
11821190
* Ensure #bins is always greater than the categories. For multiclass classification,
@@ -1187,7 +1195,7 @@ object DecisionTree extends Serializable with Logging {
11871195
if (strategy.categoricalFeaturesInfo.size > 0) {
11881196
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
11891197
require(numBins > maxCategoriesForFeatures)
1190-
if (strategy.isMultiClassification) {
1198+
if (isMulticlassClassification) {
11911199
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1)
11921200
}
11931201
}
@@ -1230,7 +1238,7 @@ object DecisionTree extends Serializable with Logging {
12301238
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
12311239

12321240
// Use different bin/split calculation strategy for multiclass classification
1233-
if (strategy.isMultiClassification) {
1241+
if (isMulticlassClassification) {
12341242
// 2^(maxFeatureValue - 1) - 1 combinations
12351243
var index = 0
12361244
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@ class Strategy (
5858
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {
5959

6060
require(numClassesForClassification >= 2)
61-
val isMultiClassification = numClassesForClassification > 2
61+
val isMulticlassClassification = numClassesForClassification > 2
6262

6363
}

0 commit comments

Comments
 (0)