Skip to content

Commit 5c44e23

Browse files
committed
For comments.
1 parent a37d3d8 commit 5c44e23

File tree

3 files changed

+41
-64
lines changed

3 files changed

+41
-64
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -720,36 +720,30 @@ private[ml] object RandomForest extends Logging {
720720
*
721721
* centroidForCategories is a list: (category, centroid)
722722
*/
723-
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
724-
// For categorical variables in multiclass classification,
725-
// the bins are ordered by the impurity of their corresponding labels.
726-
Range(0, numCategories).map { case featureValue =>
727-
val categoryStats =
728-
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
729-
val centroid = if (categoryStats.count != 0) {
723+
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
724+
val categoryStats =
725+
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
726+
val centroid = if (categoryStats.count != 0) {
727+
if (binAggregates.metadata.isMulticlass) {
728+
// multiclass classification
729+
// For categorical variables in multiclass classification,
730+
// the bins are ordered by the impurity of their corresponding labels.
730731
categoryStats.calculate()
732+
} else if (binAggregates.metadata.isClassification) {
733+
// binary classification
734+
// For categorical variables in binary classification,
735+
// the bins are ordered by the count of class 1.
736+
categoryStats.stats(1)
731737
} else {
732-
Double.MaxValue
733-
}
734-
(featureValue, centroid)
735-
}
736-
} else { // regression or binary classification
737-
// For categorical variables in regression and binary classification,
738-
// the bins are ordered by the centroid of their corresponding labels.
739-
Range(0, numCategories).map { case featureValue =>
740-
val categoryStats =
741-
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
742-
val centroid = if (categoryStats.count != 0) {
743-
if (categoryStats.count == 2) {
744-
categoryStats.stats(1)
745-
} else {
746-
categoryStats.predict
747-
}
748-
} else {
749-
Double.MaxValue
738+
// regression
739+
// For categorical variables in regression,
740+
// the bins are ordered by the prediction.
741+
categoryStats.predict
750742
}
751-
(featureValue, centroid)
743+
} else {
744+
Double.MaxValue
752745
}
746+
(featureValue, centroid)
753747
}
754748

755749
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -865,36 +865,27 @@ object DecisionTree extends Serializable with Logging {
865865
*
866866
* centroidForCategories is a list: (category, centroid)
867867
*/
868-
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
869-
// For categorical variables in multiclass classification,
870-
// the bins are ordered by the impurity of their corresponding labels.
871-
Range(0, numBins).map { case featureValue =>
872-
val categoryStats =
873-
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
874-
val centroid = if (categoryStats.count != 0) {
868+
val centroidForCategories = Range(0, numBins).map { case featureValue =>
869+
val categoryStats =
870+
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
871+
val centroid = if (categoryStats.count != 0) {
872+
if (binAggregates.metadata.isMulticlass) {
873+
// For categorical variables in multiclass classification,
874+
// the bins are ordered by the impurity of their corresponding labels.
875875
categoryStats.calculate()
876+
} else if (binAggregates.metadata.isClassification) {
877+
// For categorical variables in binary classification,
878+
// the bins are ordered by the count of class 1.
879+
categoryStats.stats(1)
876880
} else {
877-
Double.MaxValue
878-
}
879-
(featureValue, centroid)
880-
}
881-
} else { // regression or binary classification
882-
// For categorical variables in regression and binary classification,
883-
// the bins are ordered by the impurity of their corresponding labels.
884-
Range(0, numBins).map { case featureValue =>
885-
val categoryStats =
886-
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
887-
val centroid = if (categoryStats.count != 0) {
888-
if (categoryStats.count == 2) {
889-
categoryStats.stats(1)
890-
} else {
891-
categoryStats.predict
892-
}
893-
} else {
894-
Double.MaxValue
881+
// For categorical variables in regression,
882+
// the bins are ordered by the prediction.
883+
categoryStats.predict
895884
}
896-
(featureValue, centroid)
885+
} else {
886+
Double.MaxValue
897887
}
888+
(featureValue, centroid)
898889
}
899890

900891
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,8 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
289289
assert(topNode.impurity !== -1.0)
290290

291291
// set impurity and predict for child nodes
292-
if (topNode.leftNode.get.predict.predict === 0.0) {
293-
assert(topNode.rightNode.get.predict.predict === 1.0)
294-
} else {
295-
assert(topNode.leftNode.get.predict.predict === 1.0)
296-
assert(topNode.rightNode.get.predict.predict === 0.0)
297-
}
292+
assert(topNode.leftNode.get.predict.predict === 0.0)
293+
assert(topNode.rightNode.get.predict.predict === 1.0)
298294
assert(topNode.leftNode.get.impurity === 0.0)
299295
assert(topNode.rightNode.get.impurity === 0.0)
300296
}
@@ -336,12 +332,8 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
336332
assert(topNode.impurity !== -1.0)
337333

338334
// set impurity and predict for child nodes
339-
if (topNode.leftNode.get.predict.predict === 0.0) {
340-
assert(topNode.rightNode.get.predict.predict === 1.0)
341-
} else {
342-
assert(topNode.leftNode.get.predict.predict === 1.0)
343-
assert(topNode.rightNode.get.predict.predict === 0.0)
344-
}
335+
assert(topNode.leftNode.get.predict.predict === 0.0)
336+
assert(topNode.rightNode.get.predict.predict === 1.0)
345337
assert(topNode.leftNode.get.impurity === 0.0)
346338
assert(topNode.rightNode.get.impurity === 0.0)
347339
}

0 commit comments

Comments
 (0)