Skip to content

Commit bb68f47

Browse files
techaddictrxin
authored andcommitted
[Fix #79] Replace Breakable For Loops By While Loops
Author: Sandeep <sandeep@techaddict.me> Closes #503 from techaddict/fix-79 and squashes the following commits: e3f6746 [Sandeep] Style changes 07a4f6b [Sandeep] for loop to While loop 0a6d8e9 [Sandeep] Breakable for loop to While loop
1 parent 6ab7578 commit bb68f47

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import scala.util.control.Breaks._
21-
2220
import org.apache.spark.annotation.Experimental
2321
import org.apache.spark.{Logging, SparkContext}
2422
import org.apache.spark.SparkContext._
@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8280
* still survived the filters of the parent nodes.
8381
*/
8482

85-
// TODO: Convert for loop to while loop
86-
breakable {
87-
for (level <- 0 until maxDepth) {
88-
89-
logDebug("#####################################")
90-
logDebug("level = " + level)
91-
logDebug("#####################################")
92-
93-
// Find best split for all nodes at a level.
94-
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
95-
level, filters, splits, bins)
96-
97-
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
98-
// Extract info for nodes at the current level.
99-
extractNodeInfo(nodeSplitStats, level, index, nodes)
100-
// Extract info for nodes at the next lower level.
101-
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
102-
filters)
103-
logDebug("final best split = " + nodeSplitStats._1)
104-
}
105-
require(scala.math.pow(2, level) == splitsStatsForLevel.length)
106-
// Check whether all the nodes at the current level at leaves.
107-
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
108-
logDebug("all leaf = " + allLeaf)
109-
if (allLeaf) break // no more tree construction
83+
var level = 0
84+
var break = false
85+
while (level < maxDepth && !break) {
86+
87+
logDebug("#####################################")
88+
logDebug("level = " + level)
89+
logDebug("#####################################")
90+
91+
// Find best split for all nodes at a level.
92+
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
93+
level, filters, splits, bins)
94+
95+
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
96+
// Extract info for nodes at the current level.
97+
extractNodeInfo(nodeSplitStats, level, index, nodes)
98+
// Extract info for nodes at the next lower level.
99+
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
100+
filters)
101+
logDebug("final best split = " + nodeSplitStats._1)
102+
}
103+
require(scala.math.pow(2, level) == splitsStatsForLevel.length)
104+
// Check whether all the nodes at the current level at leaves.
105+
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
106+
logDebug("all leaf = " + allLeaf)
107+
if (allLeaf) {
108+
break = true // no more tree construction
109+
} else {
110+
level += 1
110111
}
111112
}
112113

@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
146147
parentImpurities: Array[Double],
147148
filters: Array[List[Filter]]): Unit = {
148149
// 0 corresponds to the left child node and 1 corresponds to the right child node.
149-
// TODO: Convert to while loop
150-
for (i <- 0 to 1) {
150+
var i = 0
151+
while (i <= 1) {
151152
// Calculate the index of the node from the node level and the index at the current level.
152153
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
153154
if (level < maxDepth - 1) {
@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
166167
logDebug("Filter = " + filter)
167168
}
168169
}
170+
i += 1
169171
}
170172
}
171173
}

0 commit comments

Comments
 (0)