Commit 61c4509

Fixed bugs from merge: missing DT timer call, and numBins setting. Cleaned up DT Suite some.
1 parent 3ba7166 commit 61c4509

3 files changed: +24 -60 lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 0 deletions

@@ -823,6 +823,7 @@ object DecisionTree extends Serializable with Logging {
 
       nodeIndex += 1
     }
+    timer.stop("chooseSplits")
 
     bestSplits
   }
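The added line is the stop() half of a start/stop timing pair around the split-selection phase; without it, the "chooseSplits" timing (presumably started earlier in findBestSplits) is never recorded. Below is a minimal sketch of a start/stop phase timer with the same call shape as `timer`, written as a hypothetical helper for illustration rather than Spark's own timer source.

// Minimal sketch of a start/stop phase timer (hypothetical helper, not Spark's implementation).
import scala.collection.mutable

class PhaseTimer {
  private val starts = mutable.Map.empty[String, Long]
  private val totals = mutable.Map.empty[String, Long]

  def start(phase: String): Unit = {
    require(!starts.contains(phase), s"timer for '$phase' already started")
    starts(phase) = System.nanoTime()
  }

  // Every start() needs a matching stop(); the bug fixed above was a missing
  // stop("chooseSplits"), which leaves the phase open and its time unrecorded.
  def stop(phase: String): Unit = {
    val t0 = starts.remove(phase)
      .getOrElse(throw new IllegalStateException(s"timer for '$phase' was never started"))
    totals(phase) = totals.getOrElse(phase, 0L) + (System.nanoTime() - t0)
  }

  override def toString: String =
    totals.map { case (phase, ns) => s"$phase: ${ns / 1e9} sec" }.mkString(", ")
}

// Usage mirroring the fixed code path:
//   timer.start("chooseSplits"); ...; timer.stop("chooseSplits"); println(timer)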

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 1 addition & 0 deletions

@@ -112,6 +112,7 @@ private[tree] object DecisionTreeMetadata {
         require(k < maxPossibleBins,
           s"maxBins (= $maxPossibleBins) should be greater than max categories " +
             s"in categorical features (>= $k)")
+        numBins(f) = k
       }
     }
   } else {
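The merge had dropped the `numBins(f) = k` assignment, so an ordered categorical feature kept its default bin count instead of one bin per category. The helper below is a rough, illustrative-only sketch of the bin-count rule this restores; it assumes the ordered/unordered decision compares the number of candidate subset splits (on the order of 2^(arity-1)) against maxBins, which is an approximation of the condition in DecisionTreeMetadata, not copied from the Spark source.

// Illustrative only: approximates the bin-count bookkeeping in DecisionTreeMetadata.
// `arity == 0` stands for a continuous feature; names mirror the diff, and the
// unordered test below is an assumption.
def binsForFeature(arity: Int, maxPossibleBins: Int, isMulticlass: Boolean): Int = {
  if (arity == 0) {
    maxPossibleBins                 // continuous: use the full bin budget
  } else {
    require(arity < maxPossibleBins,
      s"maxBins (= $maxPossibleBins) should be greater than max categories (>= $arity)")
    val subsetSplits = math.pow(2, arity - 1).toInt - 1
    if (isMulticlass && subsetSplits <= maxPossibleBins) {
      subsetSplits                  // unordered categorical: one bin per subset split (assumed)
    } else {
      arity                         // ordered categorical: numBins(f) = k, restored by this commit
    }
  }
}

For example, with maxBins = 100 and a 10-ary categorical feature, the number of candidate subset splits (on the order of 2^9 or 2^10) exceeds 100, so the feature is treated as ordered and gets 10 bins, consistent with the suite's assert(bins(0).length === 10) and its "2^10 - 1 > 100" comment below.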

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 22 additions & 60 deletions

@@ -60,20 +60,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
   }
 
-  test("split and bin calculation for continuous features") {
+  test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    assert(!metadata.isUnordered(featureIndex = 0))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
   }
 
-  test("split and bin calculation for binary features") {
+  test("Binary classification with binary features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -100,32 +101,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0)(0).threshold === Double.MinValue)
     assert(splits(0)(0).featureType === Categorical)
     assert(splits(0)(0).categories.length === 1)
-    //println(s"splits(0)(0).categories: ${splits(0)(0).categories}")
     assert(splits(0)(0).categories.contains(1.0))
 
-    /*
-    assert(splits(0)(1).feature === 0)
-    assert(splits(0)(1).threshold === Double.MinValue)
-    assert(splits(0)(1).featureType === Categorical)
-    assert(splits(0)(1).categories.length === 2)
-    assert(splits(0)(1).categories.contains(0.0))
-    assert(splits(0)(1).categories.contains(1.0))
-    */
-
     assert(splits(1)(0).feature === 1)
     assert(splits(1)(0).threshold === Double.MinValue)
     assert(splits(1)(0).featureType === Categorical)
     assert(splits(1)(0).categories.length === 1)
     assert(splits(1)(0).categories.contains(0.0))
 
-    /*
-    assert(splits(1)(1).feature === 1)
-    assert(splits(1)(1).threshold === Double.MinValue)
-    assert(splits(1)(1).featureType === Categorical)
-    assert(splits(1)(1).categories.length === 2)
-    assert(splits(1)(1).categories.contains(0.0))
-    assert(splits(1)(1).categories.contains(1.0))
-    */
     // Check bins.
 
     assert(bins(0)(0).lowSplit.categories.length === 0)
@@ -185,16 +168,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0)(1).categories.contains(0.0))
     assert(splits(0)(1).categories.contains(1.0))
 
-    /*
-    assert(splits(0)(2).feature === 0)
-    assert(splits(0)(2).threshold === Double.MinValue)
-    assert(splits(0)(2).featureType === Categorical)
-    assert(splits(0)(2).categories.length === 3)
-    assert(splits(0)(2).categories.contains(0.0))
-    assert(splits(0)(2).categories.contains(1.0))
-    assert(splits(0)(2).categories.contains(2.0))
-    */
-
     assert(splits(1)(0).feature === 1)
     assert(splits(1)(0).threshold === Double.MinValue)
     assert(splits(1)(0).featureType === Categorical)
@@ -208,16 +181,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(1)(1).categories.contains(0.0))
     assert(splits(1)(1).categories.contains(1.0))
 
-    /*
-    assert(splits(1)(2).feature === 1)
-    assert(splits(1)(2).threshold === Double.MinValue)
-    assert(splits(1)(2).featureType === Categorical)
-    assert(splits(1)(2).categories.length === 3)
-    assert(splits(1)(2).categories.contains(0.0))
-    assert(splits(1)(2).categories.contains(1.0))
-    assert(splits(1)(2).categories.contains(2.0))
-    */
-
     // Check bins.
 
     assert(bins(0)(0).lowSplit.categories.length === 0)
@@ -260,8 +223,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
   }
 
-  test("split and bin calculations for unordered categorical variables with multiclass " +
-    "classification") {
+  test("Multiclass classification with unordered categorical features:" +
+    " split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -355,8 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
-  test("split and bin calculations for ordered categorical variables with multiclass " +
-    "classification") {
+  test("Multiclass classification with ordered categorical features: split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     assert(arr.length === 3000)
     val rdd = sc.parallelize(arr)
@@ -377,7 +339,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 9)
     assert(bins(0).length === 10)
 
-    // 2^10 - 1 > 100, so categorical variables will be ordered
+    // 2^10 - 1 > 100, so categorical features will be ordered
 
     assert(splits(0)(0).feature === 0)
     assert(splits(0)(0).threshold === Double.MinValue)
@@ -413,7 +375,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
   }
 
 
-  test("classification stump with all ordered categorical variables") {
+  test("Binary classification stump with all ordered categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -450,7 +412,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
-  test("regression stump with all categorical variables") {
+  test("Regression stump with 3-ary categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -482,7 +444,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
-  test("regression stump with categorical variables of arity 2") {
+  test("Regression stump with binary categorical features") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
@@ -502,7 +464,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
   }
 
-  test("stump with fixed label 0 for Gini") {
+  test("Binary classification stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -530,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.rightImpurity === 0)
   }
 
-  test("stump with fixed label 1 for Gini") {
+  test("Binary classification stump with fixed label 1 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -559,7 +521,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("stump with fixed label 0 for Entropy") {
+  test("Binary classification stump with fixed label 0 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -588,7 +550,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 0)
   }
 
-  test("stump with fixed label 1 for Entropy") {
+  test("Binary classification stump with fixed label 1 for Entropy") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -617,7 +579,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("second level node building with/without groups") {
+  test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -669,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("stump with categorical variables for multiclass classification") {
+  test("Multiclass classification stump with 3-ary categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -692,7 +654,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
   }
 
-  test("stump with 1 continuous variable for binary classification, to check off-by-1 error") {
+  test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
     arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
@@ -708,7 +670,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.depth === 1)
   }
 
-  test("stump with 2 continuous variables for binary classification") {
+  test("Binary classification stump with 2 continuous features") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
     arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
@@ -726,7 +688,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(model.topNode.split.get.feature === 1)
   }
 
-  test("stump with categorical variables for multiclass classification, with just enough bins") {
+  test("Multiclass classification stump with categorical features, with just enough bins") {
     val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
@@ -757,7 +719,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(gain.rightImpurity === 0)
  }
 
-  test("stump with continuous variables for multiclass classification") {
+  test("Multiclass classification stump with continuous features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -783,7 +745,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
-  test("stump with continuous + categorical variables for multiclass classification") {
+  test("Multiclass classification stump with continuous + categorical features") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -808,7 +770,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.threshold < 2020)
   }
 
-  test("stump with categorical variables for ordered multiclass classification") {
+  test("Multiclass classification stump with 10-ary (ordered) categorical features") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
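All of the renamed tests share one pattern: build a Strategy, derive DecisionTreeMetadata, and check the splits/bins or the trained model. A condensed, self-contained version of that pattern is sketched below. It assumes a running SparkContext (as LocalSparkContext provides), assumes the Strategy constructor takes (algo, impurity, maxDepth, numClasses, maxBins) in that order as the suite's positional call suggests, and is placed in the org.apache.spark.mllib.tree package because buildMetadata and findSplitsBins are tree-internal helpers visible only there.

// Sketch of the test pattern used throughout DecisionTreeSuite (illustrative, not suite code).
package org.apache.spark.mllib.tree

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreePatternSketch {
  // Assumes a SparkContext named `sc`, e.g. the one LocalSparkContext supplies to the suite.
  def run(sc: SparkContext): Unit = {
    // 100 labelled points on one continuous feature: label 0 below 50, label 1 from 50 up.
    val points = (0 until 100).map { i =>
      LabeledPoint(if (i < 50) 0.0 else 1.0, Vectors.dense(i.toDouble))
    }
    val rdd = sc.parallelize(points)

    // Positional args mirror the suite: (algo, impurity, maxDepth, numClasses, maxBins).
    val strategy = new Strategy(Classification, Gini, 3, 2, 100)

    // Split/bin bookkeeping exercised by the "split and bin calculation" tests.
    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

    // Full training, as in the stump tests.
    val model = DecisionTree.train(rdd, strategy)
    println(s"depth = ${model.depth}, splits for feature 0 = ${splits(0).length}, " +
      s"bins for feature 0 = ${bins(0).length}")
  }
}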
