
Commit 84260ca

Use the soft prediction to order categories' bins.
1 parent 5f46444 commit 84260ca

3 files changed, +179 -120 lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 1 addition & 1 deletion
@@ -776,7 +776,7 @@ private[ml] object RandomForest extends Logging {
             val categoryStats =
               binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
             val centroid = if (categoryStats.count != 0) {
-              categoryStats.predict
+              categoryStats.calculate()
             } else {
               Double.MaxValue
             }
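In the old code the "centroid" used to order a category's bin was the impurity calculator's hard prediction (`predict`, essentially the majority class); after this change it is `calculate()`, the impurity of the category's full label distribution, which is what the commit title calls the soft prediction. A minimal stand-alone sketch of the difference (plain Scala with made-up counts, not Spark's internal ImpurityCalculator API):

```scala
// Toy stand-in for the relevant part of an impurity calculator; numbers are illustrative.
object CentroidSketch {
  // Gini impurity from per-class counts: 1 - sum((count / total)^2)
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  // Hard prediction: index of the majority class.
  def hardPredict(counts: Array[Double]): Double = counts.indexOf(counts.max).toDouble

  def main(args: Array[String]): Unit = {
    // category -> (count of class 0, count of class 1) for one node
    val stats = Map(0 -> Array(8.0, 2.0), 1 -> Array(6.0, 4.0), 2 -> Array(1.0, 9.0))

    // Ordering by the hard prediction: categories 0 and 1 both map to 0.0,
    // so the sort cannot tell them apart.
    val byHard = stats.toSeq.sortBy { case (_, c) => hardPredict(c) }

    // Ordering by the impurity (0.32, 0.48, 0.18): every category gets a distinct centroid.
    val bySoft = stats.toSeq.sortBy { case (_, c) => gini(c) }

    println(byHard.map(_._1).mkString(","))  // 0,1,2 or 1,0,2 (tie on the hard value)
    println(bySoft.map(_._1).mkString(","))  // 2,0,1
  }
}
```

Categories with no samples keep the `Double.MaxValue` sentinel in the real code, which pushes empty bins to the end of the ordering.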

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 117 additions & 113 deletions
@@ -811,128 +811,132 @@ object DecisionTree extends Serializable with Logging {
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
       Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
-      val featureIndex = if (featuresForNode.nonEmpty) {
-        featuresForNode.get.apply(featureIndexIdx)
-      } else {
-        featureIndexIdx
-      }
-      val numSplits = binAggregates.metadata.numSplits(featureIndex)
-      if (binAggregates.metadata.isContinuous(featureIndex)) {
-        // Cumulative sum (scanLeft) of bin statistics.
-        // Afterwards, binAggregates for a bin is the sum of aggregates for
-        // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
-          splitIndex += 1
+        val featureIndex = if (featuresForNode.nonEmpty) {
+          featuresForNode.get.apply(featureIndexIdx)
+        } else {
+          featureIndexIdx
         }
-        // Find best split.
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { case splitIdx =>
-            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-            rightChildStats.subtract(leftChildStats)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIdx, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
-        // Unordered categorical feature
-        val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { splitIndex =>
-            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIndex, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else {
-        // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-        val numBins = binAggregates.metadata.numBins(featureIndex)
-
-        /* Each bin is one category (feature value).
-         * The bins are ordered based on centroidForCategories, and this ordering determines which
-         * splits are considered. (With K categories, we consider K - 1 possible splits.)
-         *
-         * centroidForCategories is a list: (category, centroid)
-         */
-        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-          // For categorical variables in multiclass classification,
-          // the bins are ordered by the impurity of their corresponding labels.
-          Range(0, numBins).map { case featureValue =>
-            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val centroid = if (categoryStats.count != 0) {
-              categoryStats.calculate()
-            } else {
-              Double.MaxValue
-            }
-            (featureValue, centroid)
+        val numSplits = binAggregates.metadata.numSplits(featureIndex)
+        if (binAggregates.metadata.isContinuous(featureIndex)) {
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+            splitIndex += 1
           }
-        } else { // regression or binary classification
-          // For categorical variables in regression and binary classification,
-          // the bins are ordered by the centroid of their corresponding labels.
-          Range(0, numBins).map { case featureValue =>
-            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val centroid = if (categoryStats.count != 0) {
-              categoryStats.predict
-            } else {
-              Double.MaxValue
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { case splitIdx =>
+              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+              rightChildStats.subtract(leftChildStats)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIdx, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+          // Unordered categorical feature
+          val (leftChildOffset, rightChildOffset) =
+            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+        } else {
+          // Ordered categorical feature
+          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+          val numBins = binAggregates.metadata.numBins(featureIndex)
+
+          /* Each bin is one category (feature value).
+           * The bins are ordered based on centroidForCategories, and this ordering determines which
+           * splits are considered. (With K categories, we consider K - 1 possible splits.)
+           *
+           * centroidForCategories is a list: (category, centroid)
+           */
+          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
+            // For categorical variables in multiclass classification,
+            // the bins are ordered by the impurity of their corresponding labels.
+            Range(0, numBins).map { case featureValue =>
+              val categoryStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val centroid = if (categoryStats.count != 0) {
+                categoryStats.calculate()
+              } else {
+                Double.MaxValue
+              }
+              (featureValue, centroid)
+            }
+          } else { // regression or binary classification
+            // For categorical variables in regression and binary classification,
+            // the bins are ordered by the impurity of their corresponding labels.
+            Range(0, numBins).map { case featureValue =>
+              val categoryStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val centroid = if (categoryStats.count != 0) {
+                categoryStats.calculate()
+              } else {
+                Double.MaxValue
+              }
+              (featureValue, centroid)
             }
-            (featureValue, centroid)
           }
-        }

-        logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

-        // bins sorted by centroids
-        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+          // bins sorted by centroids
+          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

-        logDebug("Sorted centroids for categorical variable = " +
-          categoriesSortedByCentroid.mkString(","))
+          logDebug("Sorted centroids for categorical variable = " +
+            categoriesSortedByCentroid.mkString(","))

-        // Cumulative sum (scanLeft) of bin statistics.
-        // Afterwards, binAggregates for a bin is the sum of aggregates for
-        // that bin + all preceding bins.
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
-          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
-          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
-          splitIndex += 1
+          // Cumulative sum (scanLeft) of bin statistics.
+          // Afterwards, binAggregates for a bin is the sum of aggregates for
+          // that bin + all preceding bins.
+          var splitIndex = 0
+          while (splitIndex < numSplits) {
+            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+            splitIndex += 1
+          }
+          // lastCategory = index of bin with total aggregates for this (node, feature)
+          val lastCategory = categoriesSortedByCentroid.last._1
+          // Find best split.
+          val (bestFeatureSplitIndex, bestFeatureGainStats) =
+            Range(0, numSplits).map { splitIndex =>
+              val featureValue = categoriesSortedByCentroid(splitIndex)._1
+              val leftChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+              val rightChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+              rightChildStats.subtract(leftChildStats)
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
+              val gainStats = calculateGainForSplit(leftChildStats,
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+              (splitIndex, gainStats)
+            }.maxBy(_._2.gain)
+          val categoriesForSplit =
+            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+          val bestFeatureSplit =
+            new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+          (bestFeatureSplit, bestFeatureGainStats)
         }
-        // lastCategory = index of bin with total aggregates for this (node, feature)
-        val lastCategory = categoriesSortedByCentroid.last._1
-        // Find best split.
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { splitIndex =>
-            val featureValue = categoriesSortedByCentroid(splitIndex)._1
-            val leftChildStats =
-              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val rightChildStats =
-              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
-            rightChildStats.subtract(leftChildStats)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIndex, gainStats)
-          }.maxBy(_._2.gain)
-        val categoriesForSplit =
-          categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
-        val bestFeatureSplit =
-          new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
-        (bestFeatureSplit, bestFeatureGainStats)
-      }
     }.maxBy(_._2.gain)

     (bestSplit, bestSplitStats, predictWithImpurity.get._1)
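For context, the ordered-categorical branch above implements the usual trick of sorting the categories by a per-category centroid and then considering only the K - 1 "prefix" splits of that ordering. The real code evaluates information gain over cumulative bin aggregates; the sketch below uses plain collections and minimizes weighted child impurity, which selects the same split for a fixed parent. Everything here, including the counts, is illustrative rather than Spark's DTStatsAggregator API:

```scala
// Simplified sketch of the ordered-categorical split search in the diff above.
object OrderedCategoricalSplitSketch {

  // Gini impurity from per-class counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  def main(args: Array[String]): Unit = {
    // Per-category class counts for a binary problem (made-up numbers).
    val binStats: Map[Int, Array[Double]] =
      Map(0 -> Array(8.0, 2.0), 1 -> Array(6.0, 4.0), 2 -> Array(1.0, 9.0))

    // 1. Order the categories by a centroid; after this commit the centroid is
    //    the impurity of each category's label distribution.
    val categoriesSortedByCentroid = binStats.toList.sortBy { case (_, c) => gini(c) }

    // 2. With K categories only the K - 1 prefix splits of that ordering are
    //    candidates; score each by the weighted impurity of its two children.
    val total = binStats.values.map(_.sum).sum
    def sumCounts(xs: List[(Int, Array[Double])]): Array[Double] =
      xs.map(_._2).reduce((a, b) => a.zip(b).map { case (x, y) => x + y })

    val candidates = (1 until categoriesSortedByCentroid.length).map { k =>
      val (left, right) = categoriesSortedByCentroid.splitAt(k)
      val (lc, rc) = (sumCounts(left), sumCounts(right))
      val weightedImpurity = (lc.sum / total) * gini(lc) + (rc.sum / total) * gini(rc)
      (left.map(_._1), weightedImpurity)
    }

    val (leftCategories, score) = candidates.minBy(_._2)
    println(s"best left-side categories = ${leftCategories.mkString(",")}, score = $score")
  }
}
```

The real implementation gets the same effect in place: it merges bin statistics cumulatively in centroid order and compares each prefix's aggregates against the totals stored in the last bin.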

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 61 additions & 6 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils


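The newly imported `TestingUtils._` is what provides the approximate-equality syntax (`~==` with `absTol`/`relTol`) used by the test added below; inside a suite that has this import on the classpath, a check looks like the following (the concrete numbers are just an assumed illustration):

```scala
import org.apache.spark.mllib.util.TestingUtils._

// 0.375 is within the absolute tolerance 0.01 of 0.38, so this assertion passes.
assert(0.375 ~== 0.38 absTol 0.01)
```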
@@ -294,8 +295,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.impurity !== -1.0)

     // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
+    if (topNode.leftNode.get.predict.predict === 0.0) {
+      assert(topNode.rightNode.get.predict.predict === 1.0)
+    } else {
+      assert(topNode.leftNode.get.predict.predict === 1.0)
+      assert(topNode.rightNode.get.predict.predict === 0.0)
+    }
     assert(topNode.leftNode.get.impurity === 0.0)
     assert(topNode.rightNode.get.impurity === 0.0)
   }
@@ -337,12 +342,62 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.impurity !== -1.0)

     // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
+    if (topNode.leftNode.get.predict.predict === 0.0) {
+      assert(topNode.rightNode.get.predict.predict === 1.0)
+    } else {
+      assert(topNode.leftNode.get.predict.predict === 1.0)
+      assert(topNode.rightNode.get.predict.predict === 0.0)
+    }
     assert(topNode.leftNode.get.impurity === 0.0)
     assert(topNode.rightNode.get.impurity === 0.0)
   }

+  test("Use soft prediction for binary classification with ordered categorical features") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), // left node
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), // right node
+      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), // left node
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)), // right node
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)), // left node
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0))) // left node
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // don't enqueue leaf nodes into node queue
+    assert(nodeQueue.isEmpty)
+
+    // set impurity and predict for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0)
+
+    // set impurity and predict for child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
   test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
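For reference, the `Gini.calculate(counts, totalCount)` call that defines `impurityForRightNode` above computes 1 - Σ (countᵢ / totalCount)², so for the counts in the test it evaluates to 1 - (3/4)² - (1/4)² = 0.375. A stand-alone re-implementation of that arithmetic (not the MLlib object itself):

```scala
// Same arithmetic as the Gini impurity value used in the test above.
object GiniCheck {
  def gini(counts: Array[Double], totalCount: Double): Double =
    1.0 - counts.map(c => (c / totalCount) * (c / totalCount)).sum

  def main(args: Array[String]): Unit = {
    println(gini(Array(0.0, 3.0, 1.0), 4.0)) // 0.375
  }
}
```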
@@ -442,7 +497,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rootNode = DecisionTree.train(rdd, strategy).topNode

     val split = rootNode.split.get
-    assert(split.categories === List(1.0))
+    assert(split.categories === List(0.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)

@@ -471,7 +526,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

     val split = rootNode.split.get
     assert(split.categories.length === 1)
-    assert(split.categories.contains(1.0))
+    assert(split.categories.contains(0.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
