Skip to content

Commit 6d32ccd

Browse files
committed
In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code.
1 parent 807cd00 commit 6d32ccd

File tree

1 file changed

+14
-41
lines changed

1 file changed

+14
-41
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -813,15 +813,9 @@ object DecisionTree extends Serializable with Logging {
813813
logDebug("node impurity = " + nodeImpurity)
814814

815815
// For each (feature, split), calculate the gain, and select the best (feature, split).
816-
// Initialize with infeasible values.
817-
var bestFeatureIndex = Int.MinValue
818-
var bestSplitIndex = Int.MinValue
819-
var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
820-
var featureIndex = 0
821-
// TODO: Change loops over splits into iterators.
822-
while (featureIndex < metadata.numFeatures) {
816+
Range(0, metadata.numFeatures).map { featureIndex =>
823817
val numSplits = metadata.numSplits(featureIndex)
824-
if (metadata.isContinuous(featureIndex)) {
818+
val (bestSplitIndex, bestGainStats) = if (metadata.isContinuous(featureIndex)) {
825819
//println(s"binsToBestSplit: feature $featureIndex (continuous)")
826820
// Cumulative sum (scanLeft) of bin statistics.
827821
// Afterwards, binAggregates for a bin is the sum of aggregates for
@@ -833,39 +827,26 @@ object DecisionTree extends Serializable with Logging {
833827
splitIndex += 1
834828
}
835829
// Find best split.
836-
splitIndex = 0
837-
while (splitIndex < numSplits) {
838-
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
830+
Range(0, numSplits).map { case splitIdx =>
831+
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
839832
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
840833
rightChildStats.subtract(leftChildStats)
841834
val gainStats =
842835
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
843-
if (gainStats.gain > bestGainStats.gain) {
844-
bestGainStats = gainStats
845-
bestFeatureIndex = featureIndex
846-
bestSplitIndex = splitIndex
847-
}
848-
splitIndex += 1
849-
}
836+
(splitIdx, gainStats)
837+
}.maxBy(_._2.gain)
850838
} else if (metadata.isUnordered(featureIndex)) {
851839
//println(s"binsToBestSplit: feature $featureIndex (unordered cat)")
852840
// Unordered categorical feature
853841
val (leftChildOffset, rightChildOffset) =
854842
binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
855-
var splitIndex = 0
856-
while (splitIndex < numSplits) {
843+
Range(0, numSplits).map { splitIndex =>
857844
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
858845
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
859846
val gainStats =
860847
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
861-
//println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
862-
if (gainStats.gain > bestGainStats.gain) {
863-
bestGainStats = gainStats
864-
bestFeatureIndex = featureIndex
865-
bestSplitIndex = splitIndex
866-
}
867-
splitIndex += 1
868-
}
848+
(splitIndex, gainStats)
849+
}.maxBy(_._2.gain)
869850
} else {
870851
//println(s"binsToBestSplit: feature $featureIndex (ordered cat)")
871852
// Ordered categorical feature
@@ -880,25 +861,17 @@ object DecisionTree extends Serializable with Logging {
880861
splitIndex += 1
881862
}
882863
// Find best split.
883-
splitIndex = 0
884-
while (splitIndex < numSplits) {
864+
Range(0, numSplits).map { splitIndex =>
885865
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
886866
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
887867
rightChildStats.subtract(leftChildStats)
888868
val gainStats =
889869
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
890-
//println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
891-
if (gainStats.gain > bestGainStats.gain) {
892-
bestGainStats = gainStats
893-
bestFeatureIndex = featureIndex
894-
bestSplitIndex = splitIndex
895-
}
896-
splitIndex += 1
897-
}
870+
(splitIndex, gainStats)
871+
}.maxBy(_._2.gain)
898872
}
899-
featureIndex += 1
900-
}
901-
(bestFeatureIndex, bestSplitIndex, bestGainStats)
873+
(featureIndex, bestSplitIndex, bestGainStats)
874+
}.maxBy(_._3.gain)
902875
}
903876

904877
/**

0 commit comments

Comments
 (0)