@@ -132,7 +132,7 @@ private[tree] object LearningMetadata {
132
132
val unorderedFeatures = new mutable.HashSet [Int ]()
133
133
// numBins[featureIndex] = number of bins for feature
134
134
val numBins = Array .fill[Int ](numFeatures)(maxPossibleBins)
135
- if (numClasses > 2 ) {
135
+ if (numClasses >= 2 ) {
136
136
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
137
137
val numUnorderedBins = DecisionTree .numUnorderedBins(k)
138
138
if (numUnorderedBins < maxPossibleBins) {
@@ -204,7 +204,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
204
204
205
205
timer.findSplitsBinsTime += timer.elapsed()
206
206
207
- /*
208
207
println(s " splits: " )
209
208
for (f <- Range (0 , splits.size)) {
210
209
for (s <- Range (0 , splits(f).size)) {
@@ -217,7 +216,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
217
216
println(s " bins( $f)( $s): ${bins(f)(s)}" )
218
217
}
219
218
}
220
- */
221
219
222
220
timer.reset()
223
221
val treeInput = TreePoint .convertToTreeRDD(retaggedInput, bins, metadata).cache()
@@ -271,7 +269,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
271
269
var break = false
272
270
while (level <= maxDepth && ! break) {
273
271
274
- // println(s"LEVEL $level")
272
+ println(s " LEVEL $level" )
275
273
logDebug(" #####################################" )
276
274
logDebug(" level = " + level)
277
275
logDebug(" #####################################" )
@@ -286,11 +284,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
286
284
287
285
val levelNodeIndexOffset = DecisionTree .maxNodesInLevel(level) - 1
288
286
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
289
- /*
290
287
println(s " splitsStatsForLevel: index= $index" )
291
288
println(s " \t split: ${nodeSplitStats._1}" )
292
289
println(s " \t gain stats: ${nodeSplitStats._2}" )
293
- */
294
290
val nodeIndex = levelNodeIndexOffset + index
295
291
val isLeftChild = level != 0 && nodeIndex % 2 == 1
296
292
val parentNodeIndex = if (isLeftChild) { // -1 for root node
@@ -326,7 +322,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
326
322
}
327
323
require(DecisionTree .maxNodesInLevel(level) == splitsStatsForLevel.length)
328
324
// Check whether all the nodes at the current level are leaves.
329
- // println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
325
+ println(s " LOOP over levels: level= $level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(" ," )}" )
330
326
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
331
327
logDebug(" all leaf = " + allLeaf)
332
328
if (allLeaf) {
@@ -798,7 +794,7 @@ object DecisionTree extends Serializable with Logging {
798
794
* @param treePoint Data point being aggregated.
799
795
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
800
796
*/
801
- def multiclassWithCategoricalBinSeqOp (
797
+ def someUnorderedBinSeqOp (
802
798
agg : Array [Array [Array [ImpurityAggregator ]]],
803
799
treePoint : TreePoint ,
804
800
nodeIndex : Int ): Unit = {
@@ -885,10 +881,10 @@ object DecisionTree extends Serializable with Logging {
885
881
treePoint : TreePoint ): Array [Array [Array [ImpurityAggregator ]]] = {
886
882
val nodeIndex = treePointToNodeIndex(treePoint)
887
883
if (nodeIndex >= 0 ) { // Otherwise, example does not reach this level.
888
- if (isMulticlassWithCategoricalFeatures) {
889
- multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
890
- } else {
884
+ if (metadata.unorderedFeatures.isEmpty) {
891
885
orderedBinSeqOp(agg, treePoint, nodeIndex)
886
+ } else {
887
+ someUnorderedBinSeqOp(agg, treePoint, nodeIndex)
892
888
}
893
889
}
894
890
agg
@@ -926,7 +922,6 @@ object DecisionTree extends Serializable with Logging {
926
922
input.aggregate(initAgg)(binSeqOp, binCombOp)
927
923
}
928
924
929
- /*
930
925
println(" binAggregates:" )
931
926
for (n <- Range (0 , binAggregates.size)) {
932
927
for (f <- Range (0 , binAggregates(n).size)) {
@@ -935,7 +930,6 @@ object DecisionTree extends Serializable with Logging {
935
930
}
936
931
}
937
932
}
938
- */
939
933
940
934
timer.binAggregatesTime += timer.elapsed()
941
935
@@ -1006,11 +1000,10 @@ object DecisionTree extends Serializable with Logging {
1006
1000
1007
1001
val leftImpurity = leftNodeAgg.calculate() // Note: 0 if count = 0
1008
1002
val rightImpurity = rightNodeAgg.calculate()
1009
- /*
1003
+
1010
1004
println(s " calculateGainForSplit " )
1011
1005
println(s " \t leftImpurity = $leftImpurity, leftNodeAgg: $leftNodeAgg" )
1012
1006
println(s " \t rightImpurity = $rightImpurity, rightNodeAgg: $rightNodeAgg" )
1013
- */
1014
1007
1015
1008
val leftWeight = leftCount / totalCount.toDouble
1016
1009
val rightWeight = rightCount / totalCount.toDouble
@@ -1265,20 +1258,16 @@ object DecisionTree extends Serializable with Logging {
1265
1258
case _ => throw new IllegalArgumentException (s " Bad impurity parameter: ${metadata.impurity}" )
1266
1259
}
1267
1260
1268
- val binMultiplier = if (metadata.isMulticlassWithCategoricalFeatures) {
1269
- 2
1270
- } else {
1271
- 1
1272
- }
1273
1261
val agg = Array .fill[Array [ImpurityAggregator ]](numNodes, metadata.numFeatures)(
1274
1262
new Array [ImpurityAggregator ](0 ))
1275
1263
var nodeIndex = 0
1276
1264
while (nodeIndex < numNodes) {
1277
1265
var featureIndex = 0
1278
1266
while (featureIndex < metadata.numFeatures) {
1279
- var binIndex = 0
1267
+ val binMultiplier = if (metadata.isUnordered(featureIndex)) 2 else 1
1280
1268
val effNumBins = metadata.numBins(featureIndex) * binMultiplier
1281
1269
agg(nodeIndex)(featureIndex) = new Array [ImpurityAggregator ](effNumBins)
1270
+ var binIndex = 0
1282
1271
while (binIndex < effNumBins) {
1283
1272
agg(nodeIndex)(featureIndex)(binIndex) = impurityAggregator.newAggregator
1284
1273
binIndex += 1
0 commit comments