Skip to content

Commit dd4d3aa

Browse files
committed
Mid-process in bug fix: bug for binary classification with categorical features
* Bug: Categorical features were all treated as ordered for binary classification. This is possible but would require the bin ordering to be determined on-the-fly after the aggregation. Currently, the ordering is determined a priori and fixed for all splits. * (Temp) Fix: Treat low-arity categorical features as unordered for binary classification. * Related change: I removed most tests for isMulticlass in the code. I instead test metadata for whether there are unordered features. * Status: The bug may be fixed, but more testing needs to be done. Aggregates: The same binMultiplier (for ordered vs. unordered) was applied to all features. It is now applied on a per-feature basis.
1 parent 438a660 commit dd4d3aa

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ private[tree] object LearningMetadata {
132132
val unorderedFeatures = new mutable.HashSet[Int]()
133133
// numBins[featureIndex] = number of bins for feature
134134
val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
135-
if (numClasses > 2) {
135+
if (numClasses >= 2) {
136136
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
137137
val numUnorderedBins = DecisionTree.numUnorderedBins(k)
138138
if (numUnorderedBins < maxPossibleBins) {
@@ -204,7 +204,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
204204

205205
timer.findSplitsBinsTime += timer.elapsed()
206206

207-
/*
208207
println(s"splits:")
209208
for (f <- Range(0, splits.size)) {
210209
for (s <- Range(0, splits(f).size)) {
@@ -217,7 +216,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
217216
println(s" bins($f)($s): ${bins(f)(s)}")
218217
}
219218
}
220-
*/
221219

222220
timer.reset()
223221
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata).cache()
@@ -271,7 +269,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
271269
var break = false
272270
while (level <= maxDepth && !break) {
273271

274-
//println(s"LEVEL $level")
272+
println(s"LEVEL $level")
275273
logDebug("#####################################")
276274
logDebug("level = " + level)
277275
logDebug("#####################################")
@@ -286,11 +284,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
286284

287285
val levelNodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1
288286
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
289-
/*
290287
println(s"splitsStatsForLevel: index=$index")
291288
println(s"\t split: ${nodeSplitStats._1}")
292289
println(s"\t gain stats: ${nodeSplitStats._2}")
293-
*/
294290
val nodeIndex = levelNodeIndexOffset + index
295291
val isLeftChild = level != 0 && nodeIndex % 2 == 1
296292
val parentNodeIndex = if (isLeftChild) { // -1 for root node
@@ -326,7 +322,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
326322
}
327323
require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length)
328324
// Check whether all the nodes at the current level at leaves.
329-
//println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
325+
println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
330326
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
331327
logDebug("all leaf = " + allLeaf)
332328
if (allLeaf) {
@@ -798,7 +794,7 @@ object DecisionTree extends Serializable with Logging {
798794
* @param treePoint Data point being aggregated.
799795
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
800796
*/
801-
def multiclassWithCategoricalBinSeqOp(
797+
def someUnorderedBinSeqOp(
802798
agg: Array[Array[Array[ImpurityAggregator]]],
803799
treePoint: TreePoint,
804800
nodeIndex: Int): Unit = {
@@ -885,10 +881,10 @@ object DecisionTree extends Serializable with Logging {
885881
treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = {
886882
val nodeIndex = treePointToNodeIndex(treePoint)
887883
if (nodeIndex >= 0) { // Otherwise, example does not reach this level.
888-
if (isMulticlassWithCategoricalFeatures) {
889-
multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
890-
} else {
884+
if (metadata.unorderedFeatures.isEmpty) {
891885
orderedBinSeqOp(agg, treePoint, nodeIndex)
886+
} else {
887+
someUnorderedBinSeqOp(agg, treePoint, nodeIndex)
892888
}
893889
}
894890
agg
@@ -926,7 +922,6 @@ object DecisionTree extends Serializable with Logging {
926922
input.aggregate(initAgg)(binSeqOp, binCombOp)
927923
}
928924

929-
/*
930925
println("binAggregates:")
931926
for (n <- Range(0, binAggregates.size)) {
932927
for (f <- Range(0, binAggregates(n).size)) {
@@ -935,7 +930,6 @@ object DecisionTree extends Serializable with Logging {
935930
}
936931
}
937932
}
938-
*/
939933

940934
timer.binAggregatesTime += timer.elapsed()
941935

@@ -1006,11 +1000,10 @@ object DecisionTree extends Serializable with Logging {
10061000

10071001
val leftImpurity = leftNodeAgg.calculate() // Note: 0 if count = 0
10081002
val rightImpurity = rightNodeAgg.calculate()
1009-
/*
1003+
10101004
println(s"calculateGainForSplit")
10111005
println(s"\t leftImpurity = $leftImpurity, leftNodeAgg: $leftNodeAgg")
10121006
println(s"\t rightImpurity = $rightImpurity, rightNodeAgg: $rightNodeAgg")
1013-
*/
10141007

10151008
val leftWeight = leftCount / totalCount.toDouble
10161009
val rightWeight = rightCount / totalCount.toDouble
@@ -1265,20 +1258,16 @@ object DecisionTree extends Serializable with Logging {
12651258
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
12661259
}
12671260

1268-
val binMultiplier = if (metadata.isMulticlassWithCategoricalFeatures) {
1269-
2
1270-
} else {
1271-
1
1272-
}
12731261
val agg = Array.fill[Array[ImpurityAggregator]](numNodes, metadata.numFeatures)(
12741262
new Array[ImpurityAggregator](0))
12751263
var nodeIndex = 0
12761264
while (nodeIndex < numNodes) {
12771265
var featureIndex = 0
12781266
while (featureIndex < metadata.numFeatures) {
1279-
var binIndex = 0
1267+
val binMultiplier = if (metadata.isUnordered(featureIndex)) 2 else 1
12801268
val effNumBins = metadata.numBins(featureIndex) * binMultiplier
12811269
agg(nodeIndex)(featureIndex) = new Array[ImpurityAggregator](effNumBins)
1270+
var binIndex = 0
12821271
while (binIndex < effNumBins) {
12831272
agg(nodeIndex)(featureIndex)(binIndex) = impurityAggregator.newAggregator
12841273
binIndex += 1

0 commit comments

Comments
 (0)