Skip to content

[SPARK-10524][ML] Use the soft prediction to order categories' bins #8734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
Expand Down Expand Up @@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numCategories).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
if (binAggregates.metadata.isMulticlass) {
// multiclass classification
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
categoryStats.calculate()
} else if (binAggregates.metadata.isClassification) {
// binary classification
// For categorical variables in binary classification,
// the bins are ordered by the count of class 1.
categoryStats.stats(1)
} else {
Double.MaxValue
}
(featureValue, centroid)
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the centroid of their corresponding labels.
Range(0, numCategories).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
// regression
// For categorical variables in regression and binary classification,
// the bins are ordered by the prediction.
categoryStats.predict
} else {
Double.MaxValue
}
(featureValue, centroid)
} else {
Double.MaxValue
}
(featureValue, centroid)
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
Expand Down
219 changes: 109 additions & 110 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
Expand All @@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the centroid of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict
if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
categoryStats.calculate()
} else if (binAggregates.metadata.isClassification) {
// For categorical variables in binary classification,
// the bins are ordered by the count of class 1.
categoryStats.stats(1)
} else {
// For categorical variables in regression,
// the bins are ordered by the prediction.
categoryStats.predict
}
} else {
Double.MaxValue
}
(featureValue, centroid)
}
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))
logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))

// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

(bestSplit, bestSplitStats, predictWithImpurity.get._1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val model = dt.fit(df)
}

test("Use soft prediction for binary classification with ordered categorical features") {
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(1.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
val data = sc.parallelize(arr)
val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)

// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val dt = new DecisionTreeClassifier()
.setImpurity("gini")
.setMaxDepth(1)
.setMaxBins(3)
val model = dt.fit(df)
model.rootNode match {
case n: InternalNode =>
n.split match {
case s: CategoricalSplit =>
assert(s.leftCategories === Array(1.0))
}
}
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
Expand Down
Loading