Skip to content

Commit adc7315

Browse files
committed
support ordered categorical splits for multiclass classification
1 parent e3e8843 commit adc7315

File tree

2 files changed

+160
-31
lines changed

2 files changed

+160
-31
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging {
516516
* Find bin for one feature.
517517
*/
518518
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
519-
isFeatureContinuous: Boolean): Int = {
519+
isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
520520
val binForFeatures = bins(featureIndex)
521521
val feature = labeledPoint.features(featureIndex)
522522

@@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging {
550550
* splits. The actual left/right child allocation per split is performed in the
551551
* sequential phase of the bin aggregate operation.
552552
*/
553-
def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = {
553+
def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
554554
labeledPoint.features(featureIndex).toInt
555555
}
556556

557557
/**
558558
* Sequential search helper method to find the bin for an ordered categorical feature.
559559
*/
560-
def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = {
560+
def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
561561
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
562562
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
563563
var binIndex = 0
@@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging {
583583
} else {
584584
// Perform sequential search to find bin for categorical features.
585585
val binIndex = {
586-
if (isMulticlassClassification) {
587-
sequentialBinSearchForCategoricalFeatureInMulticlassClassification()
586+
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
587+
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
588588
} else {
589-
sequentialBinSearchForCategoricalFeatureInBinaryClassification()
589+
sequentialBinSearchForOrderedCategoricalFeatureInClassification()
590590
}
591591
}
592592
if (binIndex == -1){
@@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging {
622622
} else {
623623
var featureIndex = 0
624624
while (featureIndex < numFeatures) {
625-
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
626-
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
625+
val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
626+
val isFeatureContinuous = featureInfo.isEmpty
627+
if (isFeatureContinuous) {
628+
arr(shift + featureIndex)
629+
= findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
630+
} else {
631+
val featureCategories = featureInfo.get
632+
val isSpaceSufficientForAllCategoricalSplits
633+
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
634+
arr(shift + featureIndex)
635+
= findBin(featureIndex, labeledPoint, isFeatureContinuous,
636+
isSpaceSufficientForAllCategoricalSplits)
637+
}
627638
featureIndex += 1
628639
}
629640
}
@@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging {
731742
// Iterate over all features.
732743
var featureIndex = 0
733744
while (featureIndex < numFeatures) {
734-
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
735-
if (isContinuousFeature) {
745+
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
746+
if (isFeatureContinuous) {
736747
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
737748
} else {
738-
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
749+
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
750+
val isSpaceSufficientForAllCategoricalSplits
751+
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
752+
if (isSpaceSufficientForAllCategoricalSplits) {
753+
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
754+
} else {
755+
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
739756
}
757+
}
740758
featureIndex += 1
741759
}
742760
}
@@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging {
10931111
if (isFeatureContinuous) {
10941112
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
10951113
} else {
1096-
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1114+
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1115+
val isSpaceSufficientForAllCategoricalSplits
1116+
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
1117+
if (isSpaceSufficientForAllCategoricalSplits) {
1118+
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1119+
} else {
1120+
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1121+
}
10971122
}
10981123
} else {
10991124
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging {
11681193
numBins - 1
11691194
} else { // Categorical feature
11701195
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1171-
if (isMulticlassClassification) {
1196+
val isSpaceSufficientForAllCategoricalSplits
1197+
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
1198+
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
11721199
math.pow(2.0, featureCategories - 1).toInt - 1
11731200
} else { // Binary classification
11741201
featureCategories
@@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging {
12891316
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
12901317
require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
12911318
"in categorical features")
1292-
if (isMulticlassClassification) {
1293-
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1,
1294-
"numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
1295-
" with categorical variables")
1296-
}
12971319
}
12981320

12991321

@@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging {
13321354
}
13331355
} else { // Categorical feature
13341356
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1357+
val isSpaceSufficientForAllCategoricalSplits
1358+
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
13351359

13361360
// Use different bin/split calculation strategy for categorical features in multiclass
1337-
// classification
1338-
if (isMulticlassClassification) {
1361+
// classification that satisfy the space constraint
1362+
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
13391363
// 2^(maxFeatureValue- 1) - 1 combinations
13401364
var index = 0
13411365
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
@@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging {
13601384
}
13611385
index += 1
13621386
}
1363-
} else { // regression or binary classification
1364-
1365-
// For categorical variables, each bin is a category. The bins are sorted and they
1366-
// are ordered by calculating the centroid of their corresponding labels.
1367-
val centroidForCategories =
1368-
sampledInput.map(lp => (lp.features(featureIndex),lp.label))
1369-
.groupBy(_._1)
1370-
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1387+
} else {
1388+
1389+
val centroidForCategories = {
1390+
if (isMulticlassClassification) {
1391+
// For categorical variables in multiclass classification,
1392+
// each bin is a category. The bins are sorted and they
1393+
// are ordered by the impurity of each category's label distribution.
1394+
sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1395+
.groupBy(_._1)
1396+
.mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
1397+
.map(x => (x._1, x._2.values.toArray))
1398+
.map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
1399+
} else { // regression or binary classification
1400+
// For categorical variables in regression and binary classification,
1401+
// each bin is a category. The bins are sorted and they
1402+
// are ordered by calculating the centroid of their corresponding labels.
1403+
sampledInput.map(lp => (lp.features(featureIndex), lp.label))
1404+
.groupBy(_._1)
1405+
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)
1406+
}
1407+
}
1408+
1409+
logDebug("centriod for categories = " + centroidForCategories.mkString(","))
13711410

13721411
// Check for missing categorical variables and putting them last in the sorted list.
13731412
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
239239
assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
240240
}
241241

242-
test("split and bin calculations for categorical variables with multiclass classification") {
242+
test("split and bin calculations for unordered categorical variables with multiclass " +
243+
"classification") {
243244
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
244245
assert(arr.length === 1000)
245246
val rdd = sc.parallelize(arr)
@@ -332,6 +333,62 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
332333

333334
}
334335

336+
test("split and bin calculations for ordered categorical variables with multiclass " +
337+
"classification") {
338+
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
339+
assert(arr.length === 3000)
340+
val rdd = sc.parallelize(arr)
341+
val strategy = new Strategy(
342+
Classification,
343+
Gini,
344+
maxDepth = 3,
345+
numClassesForClassification = 100,
346+
maxBins = 100,
347+
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
348+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
349+
350+
// 2^(10 - 1) - 1 = 511 > 100, so the space is insufficient and categorical variables will be ordered
351+
352+
assert(splits(0)(0).feature === 0)
353+
assert(splits(0)(0).threshold === Double.MinValue)
354+
assert(splits(0)(0).featureType === Categorical)
355+
assert(splits(0)(0).categories.length === 1)
356+
assert(splits(0)(0).categories.contains(1.0))
357+
358+
assert(splits(0)(1).feature === 0)
359+
assert(splits(0)(1).threshold === Double.MinValue)
360+
assert(splits(0)(1).featureType === Categorical)
361+
assert(splits(0)(1).categories.length === 2)
362+
assert(splits(0)(1).categories.contains(2.0))
363+
364+
assert(splits(0)(2).feature === 0)
365+
assert(splits(0)(2).threshold === Double.MinValue)
366+
assert(splits(0)(2).featureType === Categorical)
367+
assert(splits(0)(2).categories.length === 3)
368+
assert(splits(0)(2).categories.contains(2.0))
369+
assert(splits(0)(2).categories.contains(1.0))
370+
371+
assert(splits(0)(10) === null)
372+
assert(splits(1)(10) === null)
373+
374+
375+
// Check bins.
376+
377+
assert(bins(0)(0).category === 1.0)
378+
assert(bins(0)(0).lowSplit.categories.length === 0)
379+
assert(bins(0)(0).highSplit.categories.length === 1)
380+
assert(bins(0)(0).highSplit.categories.contains(1.0))
381+
assert(bins(0)(1).category === 2.0)
382+
assert(bins(0)(1).lowSplit.categories.length === 1)
383+
assert(bins(0)(1).highSplit.categories.length === 2)
384+
assert(bins(0)(1).highSplit.categories.contains(1.0))
385+
assert(bins(0)(1).highSplit.categories.contains(2.0))
386+
387+
assert(bins(0)(10) === null)
388+
389+
}
390+
391+
335392
test("classification stump with all categorical variables") {
336393
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
337394
assert(arr.length === 1000)
@@ -547,7 +604,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
547604
assert(bestSplit.featureType === Categorical)
548605
}
549606

550-
551607
test("stump with continuous variables for multiclass classification") {
552608
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
553609
val input = sc.parallelize(arr)
@@ -568,7 +624,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
568624

569625
}
570626

571-
572627
test("stump with continuous + categorical variables for multiclass classification") {
573628
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
574629
val input = sc.parallelize(arr)
@@ -588,6 +643,26 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
588643
assert(bestSplit.threshold < 2020)
589644
}
590645

646+
test("stump with categorical variables for ordered multiclass classification") {
647+
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
648+
val input = sc.parallelize(arr)
649+
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
650+
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
651+
assert(strategy.isMulticlassClassification)
652+
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
653+
val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
654+
Array[List[Filter]](), splits, bins, 10)
655+
656+
assert(bestSplits.length === 1)
657+
val bestSplit = bestSplits(0)._1
658+
assert(bestSplit.feature === 0)
659+
assert(bestSplit.categories.length === 1)
660+
println(bestSplit)
661+
assert(bestSplit.categories.contains(1.0))
662+
assert(bestSplit.featureType === Categorical)
663+
}
664+
665+
591666
}
592667

593668
object DecisionTreeSuite {
@@ -662,5 +737,20 @@ object DecisionTreeSuite {
662737
arr
663738
}
664739

740+
def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
741+
Array[WeightedLabeledPoint] = {
742+
val arr = new Array[WeightedLabeledPoint](3000)
743+
for (i <- 0 until 3000) {
744+
if (i < 1000) {
745+
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
746+
} else if (i < 2000) {
747+
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
748+
} else {
749+
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0))
750+
}
751+
}
752+
arr
753+
}
754+
665755

666756
}

0 commit comments

Comments
 (0)