@@ -811,128 +811,132 @@ object DecisionTree extends Serializable with Logging {
811
811
// For each (feature, split), calculate the gain, and select the best (feature, split).
812
812
val (bestSplit, bestSplitStats) =
813
813
Range (0 , binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
814
- val featureIndex = if (featuresForNode.nonEmpty) {
815
- featuresForNode.get.apply(featureIndexIdx)
816
- } else {
817
- featureIndexIdx
818
- }
819
- val numSplits = binAggregates.metadata.numSplits(featureIndex)
820
- if (binAggregates.metadata.isContinuous(featureIndex)) {
821
- // Cumulative sum (scanLeft) of bin statistics.
822
- // Afterwards, binAggregates for a bin is the sum of aggregates for
823
- // that bin + all preceding bins.
824
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
825
- var splitIndex = 0
826
- while (splitIndex < numSplits) {
827
- binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1 , splitIndex)
828
- splitIndex += 1
814
+ val featureIndex = if (featuresForNode.nonEmpty) {
815
+ featuresForNode.get.apply(featureIndexIdx)
816
+ } else {
817
+ featureIndexIdx
829
818
}
830
- // Find best split.
831
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
832
- Range (0 , numSplits).map { case splitIdx =>
833
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
834
- val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
835
- rightChildStats.subtract(leftChildStats)
836
- predictWithImpurity = Some (predictWithImpurity.getOrElse(
837
- calculatePredictImpurity(leftChildStats, rightChildStats)))
838
- val gainStats = calculateGainForSplit(leftChildStats,
839
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
840
- (splitIdx, gainStats)
841
- }.maxBy(_._2.gain)
842
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
843
- } else if (binAggregates.metadata.isUnordered(featureIndex)) {
844
- // Unordered categorical feature
845
- val (leftChildOffset, rightChildOffset) =
846
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
847
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
848
- Range (0 , numSplits).map { splitIndex =>
849
- val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
850
- val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
851
- predictWithImpurity = Some (predictWithImpurity.getOrElse(
852
- calculatePredictImpurity(leftChildStats, rightChildStats)))
853
- val gainStats = calculateGainForSplit(leftChildStats,
854
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
855
- (splitIndex, gainStats)
856
- }.maxBy(_._2.gain)
857
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
858
- } else {
859
- // Ordered categorical feature
860
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
861
- val numBins = binAggregates.metadata.numBins(featureIndex)
862
-
863
- /* Each bin is one category (feature value).
864
- * The bins are ordered based on centroidForCategories, and this ordering determines which
865
- * splits are considered. (With K categories, we consider K - 1 possible splits.)
866
- *
867
- * centroidForCategories is a list: (category, centroid)
868
- */
869
- val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
870
- // For categorical variables in multiclass classification,
871
- // the bins are ordered by the impurity of their corresponding labels.
872
- Range (0 , numBins).map { case featureValue =>
873
- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
874
- val centroid = if (categoryStats.count != 0 ) {
875
- categoryStats.calculate()
876
- } else {
877
- Double .MaxValue
878
- }
879
- (featureValue, centroid)
819
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
820
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
821
+ // Cumulative sum (scanLeft) of bin statistics.
822
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
823
+ // that bin + all preceding bins.
824
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
825
+ var splitIndex = 0
826
+ while (splitIndex < numSplits) {
827
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1 , splitIndex)
828
+ splitIndex += 1
880
829
}
881
- } else { // regression or binary classification
882
- // For categorical variables in regression and binary classification,
883
- // the bins are ordered by the centroid of their corresponding labels.
884
- Range (0 , numBins).map { case featureValue =>
885
- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
886
- val centroid = if (categoryStats.count != 0 ) {
887
- categoryStats.predict
888
- } else {
889
- Double .MaxValue
830
+ // Find best split.
831
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
832
+ Range (0 , numSplits).map { case splitIdx =>
833
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
834
+ val rightChildStats =
835
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
836
+ rightChildStats.subtract(leftChildStats)
837
+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
838
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
839
+ val gainStats = calculateGainForSplit(leftChildStats,
840
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
841
+ (splitIdx, gainStats)
842
+ }.maxBy(_._2.gain)
843
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
844
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
845
+ // Unordered categorical feature
846
+ val (leftChildOffset, rightChildOffset) =
847
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
848
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
849
+ Range (0 , numSplits).map { splitIndex =>
850
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
851
+ val rightChildStats =
852
+ binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
853
+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
854
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
855
+ val gainStats = calculateGainForSplit(leftChildStats,
856
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
857
+ (splitIndex, gainStats)
858
+ }.maxBy(_._2.gain)
859
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
860
+ } else {
861
+ // Ordered categorical feature
862
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
863
+ val numBins = binAggregates.metadata.numBins(featureIndex)
864
+
865
+ /* Each bin is one category (feature value).
866
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
867
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
868
+ *
869
+ * centroidForCategories is a list: (category, centroid)
870
+ */
871
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
872
+ // For categorical variables in multiclass classification,
873
+ // the bins are ordered by the impurity of their corresponding labels.
874
+ Range (0 , numBins).map { case featureValue =>
875
+ val categoryStats =
876
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
877
+ val centroid = if (categoryStats.count != 0 ) {
878
+ categoryStats.calculate()
879
+ } else {
880
+ Double .MaxValue
881
+ }
882
+ (featureValue, centroid)
883
+ }
884
+ } else { // regression or binary classification
885
+ // For categorical variables in regression and binary classification,
886
+ // the bins are ordered by the impurity of their corresponding labels.
887
+ Range (0 , numBins).map { case featureValue =>
888
+ val categoryStats =
889
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
890
+ val centroid = if (categoryStats.count != 0 ) {
891
+ categoryStats.calculate()
892
+ } else {
893
+ Double .MaxValue
894
+ }
895
+ (featureValue, centroid)
890
896
}
891
- (featureValue, centroid)
892
897
}
893
- }
894
898
895
- logDebug(" Centroids for categorical variable: " + centroidForCategories.mkString(" ," ))
899
+ logDebug(" Centroids for categorical variable: " + centroidForCategories.mkString(" ," ))
896
900
897
- // bins sorted by centroids
898
- val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
901
+ // bins sorted by centroids
902
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
899
903
900
- logDebug(" Sorted centroids for categorical variable = " +
901
- categoriesSortedByCentroid.mkString(" ," ))
904
+ logDebug(" Sorted centroids for categorical variable = " +
905
+ categoriesSortedByCentroid.mkString(" ," ))
902
906
903
- // Cumulative sum (scanLeft) of bin statistics.
904
- // Afterwards, binAggregates for a bin is the sum of aggregates for
905
- // that bin + all preceding bins.
906
- var splitIndex = 0
907
- while (splitIndex < numSplits) {
908
- val currentCategory = categoriesSortedByCentroid(splitIndex)._1
909
- val nextCategory = categoriesSortedByCentroid(splitIndex + 1 )._1
910
- binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
911
- splitIndex += 1
907
+ // Cumulative sum (scanLeft) of bin statistics.
908
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
909
+ // that bin + all preceding bins.
910
+ var splitIndex = 0
911
+ while (splitIndex < numSplits) {
912
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
913
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1 )._1
914
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
915
+ splitIndex += 1
916
+ }
917
+ // lastCategory = index of bin with total aggregates for this (node, feature)
918
+ val lastCategory = categoriesSortedByCentroid.last._1
919
+ // Find best split.
920
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
921
+ Range (0 , numSplits).map { splitIndex =>
922
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
923
+ val leftChildStats =
924
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
925
+ val rightChildStats =
926
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
927
+ rightChildStats.subtract(leftChildStats)
928
+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
929
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
930
+ val gainStats = calculateGainForSplit(leftChildStats,
931
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
932
+ (splitIndex, gainStats)
933
+ }.maxBy(_._2.gain)
934
+ val categoriesForSplit =
935
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
936
+ val bestFeatureSplit =
937
+ new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
938
+ (bestFeatureSplit, bestFeatureGainStats)
912
939
}
913
- // lastCategory = index of bin with total aggregates for this (node, feature)
914
- val lastCategory = categoriesSortedByCentroid.last._1
915
- // Find best split.
916
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
917
- Range (0 , numSplits).map { splitIndex =>
918
- val featureValue = categoriesSortedByCentroid(splitIndex)._1
919
- val leftChildStats =
920
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
921
- val rightChildStats =
922
- binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
923
- rightChildStats.subtract(leftChildStats)
924
- predictWithImpurity = Some (predictWithImpurity.getOrElse(
925
- calculatePredictImpurity(leftChildStats, rightChildStats)))
926
- val gainStats = calculateGainForSplit(leftChildStats,
927
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
928
- (splitIndex, gainStats)
929
- }.maxBy(_._2.gain)
930
- val categoriesForSplit =
931
- categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
932
- val bestFeatureSplit =
933
- new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
934
- (bestFeatureSplit, bestFeatureGainStats)
935
- }
936
940
}.maxBy(_._2.gain)
937
941
938
942
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
0 commit comments