[SPARK-3381] [MLlib] Eliminate bins for unordered features
MechCoder committed Feb 17, 2015
1 parent 3ce58cf commit d3ee042
Showing 3 changed files with 11 additions and 64 deletions.
@@ -328,13 +328,15 @@ object DecisionTree extends Serializable with Logging {
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param splits possible splits indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
bins: Array[Array[Bin]],
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
@@ -362,7 +364,7 @@ object DecisionTree extends Serializable with Logging {
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
@@ -506,8 +508,8 @@ object DecisionTree extends Serializable with Logging {
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
}
}
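
The two hunks above are the core of the change: instead of reaching a split's category set through `bins(featureIndex)(splitIndex).highSplit`, the aggregation reads `categories` directly off the `Split`. A minimal, self-contained sketch of that membership test, using hypothetical types rather than Spark's internal `Split`/`Bin` classes:

```scala
// Self-contained sketch (hypothetical types, not Spark's internals): an unordered
// categorical split already carries the subset of category values routed to the
// left child, so the left/right decision needs only the split, not a Bin.
object UnorderedSplitMembershipSketch {
  final case class CategoricalSplit(featureIndex: Int, leftCategories: Set[Double])

  def goesLeft(split: CategoricalSplit, featureValue: Double): Boolean =
    split.leftCategories.contains(featureValue)

  def main(args: Array[String]): Unit = {
    val split = CategoricalSplit(featureIndex = 0, leftCategories = Set(0.0, 2.0))
    println(goesLeft(split, 2.0)) // true: category 2.0 is in the split's subset
    println(goesLeft(split, 1.0)) // false: category 1.0 falls to the right child
  }
}
```

Because the split already names the categories that go left, the per-feature `Bin` objects for unordered features carry no extra information, which is what lets the later hunks stop building them.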
@@ -1024,12 +1026,6 @@ object DecisionTree extends Serializable with Logging {
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
// TODO: The second half of the bins are unused. Actually, we could just use
// splits and not build bins for unordered features. That should be part of
// a later PR since it will require changing other code (using splits instead
// of bins in a few places).
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
var splitIndex = 0
@@ -1038,30 +1034,18 @@
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(splitIndex) = {
if (splitIndex == 0) {
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
Categorical,
Double.MinValue)
} else {
new Bin(
splits(featureIndex)(splitIndex - 1),
splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
splitIndex += 1
}
} else {
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
bins(featureIndex) = new Array[Bin](0)
}
// For ordered features, bins correspond to feature values.
// For unordered categorical features, there is no need to construct the bins,
// since there is a one-to-one correspondence between the splits and the bins.
bins(featureIndex) = new Array[Bin](0)
}
featureIndex += 1
}
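
In the construction loop above, each unordered split's category subset is derived from `extractMultiClassCategories(splitIndex + 1, featureArity)`, so the subset is fully determined by the split index and no parallel array of `Bin`s is needed. A rough sketch of that enumeration, assuming the usual bit-mask decoding of `splitIndex + 1` over categories `0 ... arity - 1` (the object and method names here are illustrative, not Spark's):

```scala
// Illustrative sketch: enumerate the 2^(arity - 1) - 1 category subsets that back
// the unordered splits by decoding (splitIndex + 1) as a bit mask over categories.
object UnorderedSplitEnumerationSketch {
  def categorySubset(mask: Int, arity: Int): List[Double] =
    (0 until arity).collect { case c if ((mask >> c) & 1) == 1 => c.toDouble }.toList

  def main(args: Array[String]): Unit = {
    val arity = 3
    val numSplits = (1 << (arity - 1)) - 1 // 2^(3 - 1) - 1 = 3
    (0 until numSplits).foreach { splitIndex =>
      println(s"split $splitIndex -> categories ${categorySubset(splitIndex + 1, arity)}")
    }
    // split 0 -> categories List(0.0)
    // split 1 -> categories List(1.0)
    // split 2 -> categories List(0.0, 1.0)
  }
}
```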
@@ -86,7 +86,7 @@ private[tree] object TreePoint {
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
isUnordered(featureIndex), bins)
bins)
featureIndex += 1
}
new TreePoint(labeledPoint.label, arr)
@@ -96,14 +96,12 @@
* Find bin for one (labeledPoint, feature).
*
* @param featureArity 0 for continuous features; number of categories for categorical features.
* @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
featureArity: Int,
isUnorderedFeature: Boolean,
bins: Array[Array[Bin]]): Int = {

/**
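
The dropped `isUnorderedFeature` flag only applied to categorical features, and the hunks above remove the parameter without touching `findBin`'s body, which suggests the binned value of a categorical feature is simply its category index in either case. A minimal sketch under that assumption (hypothetical helper, not Spark's `findBin`):

```scala
// Minimal sketch of the assumption above (hypothetical helper, not Spark's findBin):
// a categorical value encoded as 0.0 .. (arity - 1) maps directly to its bin index,
// identically for ordered and unordered features.
object CategoricalBinSketch {
  def categoricalBin(featureValue: Double, featureArity: Int): Int = {
    val bin = featureValue.toInt
    require(bin >= 0 && bin < featureArity,
      s"Categorical value $featureValue is outside the expected range [0, $featureArity).")
    bin
  }

  def main(args: Array[String]): Unit = {
    println(categoricalBin(2.0, featureArity = 3)) // 2, regardless of orderedness
  }
}
```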
@@ -190,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 3)
assert(bins(0).length === 6)
assert(bins(0).length === 0)

// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -228,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))

// Check bins.

assert(bins(0)(0).category === Double.MinValue)
assert(bins(0)(0).lowSplit.categories.length === 0)
assert(bins(0)(0).highSplit.categories.length === 1)
assert(bins(0)(0).highSplit.categories.contains(0.0))
assert(bins(1)(0).category === Double.MinValue)
assert(bins(1)(0).lowSplit.categories.length === 0)
assert(bins(1)(0).highSplit.categories.length === 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))

assert(bins(0)(1).category === Double.MinValue)
assert(bins(0)(1).lowSplit.categories.length === 1)
assert(bins(0)(1).lowSplit.categories.contains(0.0))
assert(bins(0)(1).highSplit.categories.length === 1)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(1)(1).category === Double.MinValue)
assert(bins(1)(1).lowSplit.categories.length === 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length === 1)
assert(bins(1)(1).highSplit.categories.contains(1.0))

assert(bins(0)(2).category === Double.MinValue)
assert(bins(0)(2).lowSplit.categories.length === 1)
assert(bins(0)(2).lowSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.length === 2)
assert(bins(0)(2).highSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.contains(0.0))
assert(bins(1)(2).category === Double.MinValue)
assert(bins(1)(2).lowSplit.categories.length === 1)
assert(bins(1)(2).lowSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.length === 2)
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(0.0))

}

test("Multiclass classification with ordered categorical features: split and bin calculations") {
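For reference, the assertions in the unordered-features test above match the counting behind this change: with three categories per feature (the arity implied by the "2^2 - 1 = 3" comment), there are 2^(3 - 1) - 1 = 3 unordered splits with category subsets {0.0}, {1.0}, and {0.0, 1.0}. Since each split carries its own subset, the bins array for such a feature can now be empty, hence `bins(0).length === 0`.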
