@@ -813,15 +813,9 @@ object DecisionTree extends Serializable with Logging {
813
813
logDebug(" node impurity = " + nodeImpurity)
814
814
815
815
// For each (feature, split), calculate the gain, and select the best (feature, split).
816
- // Initialize with infeasible values.
817
- var bestFeatureIndex = Int .MinValue
818
- var bestSplitIndex = Int .MinValue
819
- var bestGainStats = new InformationGainStats (Double .MinValue , - 1.0 , - 1.0 , - 1.0 , - 1.0 )
820
- var featureIndex = 0
821
- // TODO: Change loops over splits into iterators.
822
- while (featureIndex < metadata.numFeatures) {
816
+ Range (0 , metadata.numFeatures).map { featureIndex =>
823
817
val numSplits = metadata.numSplits(featureIndex)
824
- if (metadata.isContinuous(featureIndex)) {
818
+ val (bestSplitIndex, bestGainStats) = if (metadata.isContinuous(featureIndex)) {
825
819
// println(s"binsToBestSplit: feature $featureIndex (continuous)")
826
820
// Cumulative sum (scanLeft) of bin statistics.
827
821
// Afterwards, binAggregates for a bin is the sum of aggregates for
@@ -833,39 +827,26 @@ object DecisionTree extends Serializable with Logging {
833
827
splitIndex += 1
834
828
}
835
829
// Find best split.
836
- splitIndex = 0
837
- while (splitIndex < numSplits) {
838
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
830
+ Range (0 , numSplits).map { case splitIdx =>
831
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
839
832
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
840
833
rightChildStats.subtract(leftChildStats)
841
834
val gainStats =
842
835
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
843
- if (gainStats.gain > bestGainStats.gain) {
844
- bestGainStats = gainStats
845
- bestFeatureIndex = featureIndex
846
- bestSplitIndex = splitIndex
847
- }
848
- splitIndex += 1
849
- }
836
+ (splitIdx, gainStats)
837
+ }.maxBy(_._2.gain)
850
838
} else if (metadata.isUnordered(featureIndex)) {
851
839
// println(s"binsToBestSplit: feature $featureIndex (unordered cat)")
852
840
// Unordered categorical feature
853
841
val (leftChildOffset, rightChildOffset) =
854
842
binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
855
- var splitIndex = 0
856
- while (splitIndex < numSplits) {
843
+ Range (0 , numSplits).map { splitIndex =>
857
844
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
858
845
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
859
846
val gainStats =
860
847
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
861
- // println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
862
- if (gainStats.gain > bestGainStats.gain) {
863
- bestGainStats = gainStats
864
- bestFeatureIndex = featureIndex
865
- bestSplitIndex = splitIndex
866
- }
867
- splitIndex += 1
868
- }
848
+ (splitIndex, gainStats)
849
+ }.maxBy(_._2.gain)
869
850
} else {
870
851
// println(s"binsToBestSplit: feature $featureIndex (ordered cat)")
871
852
// Ordered categorical feature
@@ -880,25 +861,17 @@ object DecisionTree extends Serializable with Logging {
880
861
splitIndex += 1
881
862
}
882
863
// Find best split.
883
- splitIndex = 0
884
- while (splitIndex < numSplits) {
864
+ Range (0 , numSplits).map { splitIndex =>
885
865
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
886
866
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
887
867
rightChildStats.subtract(leftChildStats)
888
868
val gainStats =
889
869
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
890
- // println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
891
- if (gainStats.gain > bestGainStats.gain) {
892
- bestGainStats = gainStats
893
- bestFeatureIndex = featureIndex
894
- bestSplitIndex = splitIndex
895
- }
896
- splitIndex += 1
897
- }
870
+ (splitIndex, gainStats)
871
+ }.maxBy(_._2.gain)
898
872
}
899
- featureIndex += 1
900
- }
901
- (bestFeatureIndex, bestSplitIndex, bestGainStats)
873
+ (featureIndex, bestSplitIndex, bestGainStats)
874
+ }.maxBy(_._3.gain)
902
875
}
903
876
904
877
/**
0 commit comments