Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3381] [MLlib] Eliminate bins for unordered features in DecisionTrees #4231

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 10 additions & 27 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,14 @@ object DecisionTree extends Serializable with Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param splits possible splits indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
bins: Array[Array[Bin]],
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
Expand Down Expand Up @@ -362,7 +362,7 @@ object DecisionTree extends Serializable with Logging {
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
Expand Down Expand Up @@ -506,8 +506,8 @@ object DecisionTree extends Serializable with Logging {
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
}
}
Expand Down Expand Up @@ -1024,44 +1024,27 @@ object DecisionTree extends Serializable with Logging {
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
// TODO: The second half of the bins are unused. Actually, we could just use
// splits and not build bins for unordered features. That should be part of
// a later PR since it will require changing other code (using splits instead
// of bins in a few places).
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
// 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(splitIndex) = {
if (splitIndex == 0) {
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
Categorical,
Double.MinValue)
} else {
new Bin(
splits(featureIndex)(splitIndex - 1),
splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
splitIndex += 1
}
} else {
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
bins(featureIndex) = new Array[Bin](0)
}
// For ordered features, bins correspond to feature values.
// For unordered categorical features, there is no need to construct the bins.
// since there is a one-to-one correspondence between the splits and the bins.
bins(featureIndex) = new Array[Bin](0)
}
featureIndex += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,15 @@ private[tree] object TreePoint {
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
// Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
// Construct arrays for featureArity for efficiency in the inner loop.
val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
var featureIndex = 0
while (featureIndex < metadata.numFeatures) {
featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
featureIndex += 1
}
input.map { x =>
TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
TreePoint.labeledPointToTreePoint(x, bins, featureArity)
}
}

Expand All @@ -74,19 +72,17 @@ private[tree] object TreePoint {
* @param bins Bins for features, of size (numFeatures, numBins).
* @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
* for categorical features.
* @param isUnordered Array index by feature, with value true for unordered categorical features.
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
featureArity: Array[Int],
isUnordered: Array[Boolean]): TreePoint = {
featureArity: Array[Int]): TreePoint = {
val numFeatures = labeledPoint.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
isUnordered(featureIndex), bins)
bins)
featureIndex += 1
}
new TreePoint(labeledPoint.label, arr)
Expand All @@ -96,14 +92,12 @@ private[tree] object TreePoint {
* Find bin for one (labeledPoint, feature).
*
* @param featureArity 0 for continuous features; number of categories for categorical features.
* @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jkbradley I removed this param as it is unused. I don't think it is a problem since all tests pass.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Can you please also remove it from labeledPointToTreePoint and not compute it in convertToTreeRDD?

featureArity: Int,
isUnorderedFeature: Boolean,
bins: Array[Array[Bin]]): Int = {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 3)
assert(bins(0).length === 6)
assert(bins(0).length === 0)

// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
Expand Down Expand Up @@ -228,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))

// Check bins.

assert(bins(0)(0).category === Double.MinValue)
assert(bins(0)(0).lowSplit.categories.length === 0)
assert(bins(0)(0).highSplit.categories.length === 1)
assert(bins(0)(0).highSplit.categories.contains(0.0))
assert(bins(1)(0).category === Double.MinValue)
assert(bins(1)(0).lowSplit.categories.length === 0)
assert(bins(1)(0).highSplit.categories.length === 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))

assert(bins(0)(1).category === Double.MinValue)
assert(bins(0)(1).lowSplit.categories.length === 1)
assert(bins(0)(1).lowSplit.categories.contains(0.0))
assert(bins(0)(1).highSplit.categories.length === 1)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(1)(1).category === Double.MinValue)
assert(bins(1)(1).lowSplit.categories.length === 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length === 1)
assert(bins(1)(1).highSplit.categories.contains(1.0))

assert(bins(0)(2).category === Double.MinValue)
assert(bins(0)(2).lowSplit.categories.length === 1)
assert(bins(0)(2).lowSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.length === 2)
assert(bins(0)(2).highSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.contains(0.0))
assert(bins(1)(2).category === Double.MinValue)
assert(bins(1)(2).lowSplit.categories.length === 1)
assert(bins(1)(2).lowSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.length === 2)
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(0.0))

}

test("Multiclass classification with ordered categorical features: split and bin calculations") {
Expand Down