17
17
18
18
package org .apache .spark .mllib .tree
19
19
20
- import scala .util .control .Breaks ._
21
-
22
20
import org .apache .spark .annotation .Experimental
23
21
import org .apache .spark .{Logging , SparkContext }
24
22
import org .apache .spark .SparkContext ._
@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
82
80
* still survived the filters of the parent nodes.
83
81
*/
84
82
85
- // TODO: Convert for loop to while loop
86
- breakable {
87
- for (level <- 0 until maxDepth) {
88
-
89
- logDebug(" #####################################" )
90
- logDebug(" level = " + level)
91
- logDebug(" #####################################" )
92
-
93
- // Find best split for all nodes at a level.
94
- val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy,
95
- level, filters, splits, bins)
96
-
97
- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
98
- // Extract info for nodes at the current level.
99
- extractNodeInfo(nodeSplitStats, level, index, nodes)
100
- // Extract info for nodes at the next lower level.
101
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
102
- filters)
103
- logDebug(" final best split = " + nodeSplitStats._1)
104
- }
105
- require(scala.math.pow(2 , level) == splitsStatsForLevel.length)
106
- // Check whether all the nodes at the current level at leaves.
107
- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
108
- logDebug(" all leaf = " + allLeaf)
109
- if (allLeaf) break // no more tree construction
83
+ var level = 0
84
+ var break = false
85
+ while (level < maxDepth && ! break) {
86
+
87
+ logDebug(" #####################################" )
88
+ logDebug(" level = " + level)
89
+ logDebug(" #####################################" )
90
+
91
+ // Find best split for all nodes at a level.
92
+ val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy,
93
+ level, filters, splits, bins)
94
+
95
+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
96
+ // Extract info for nodes at the current level.
97
+ extractNodeInfo(nodeSplitStats, level, index, nodes)
98
+ // Extract info for nodes at the next lower level.
99
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
100
+ filters)
101
+ logDebug(" final best split = " + nodeSplitStats._1)
102
+ }
103
+ require(scala.math.pow(2 , level) == splitsStatsForLevel.length)
104
+ // Check whether all the nodes at the current level at leaves.
105
+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
106
+ logDebug(" all leaf = " + allLeaf)
107
+ if (allLeaf) {
108
+ break = true // no more tree construction
109
+ } else {
110
+ level += 1
110
111
}
111
112
}
112
113
@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
146
147
parentImpurities : Array [Double ],
147
148
filters : Array [List [Filter ]]): Unit = {
148
149
// 0 corresponds to the left child node and 1 corresponds to the right child node.
149
- // TODO: Convert to while loop
150
- for (i <- 0 to 1 ) {
150
+ var i = 0
151
+ while (i <= 1 ) {
151
152
// Calculate the index of the node from the node level and the index at the current level.
152
153
val nodeIndex = scala.math.pow(2 , level + 1 ).toInt - 1 + 2 * index + i
153
154
if (level < maxDepth - 1 ) {
@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
166
167
logDebug(" Filter = " + filter)
167
168
}
168
169
}
170
+ i += 1
169
171
}
170
172
}
171
173
}
0 commit comments