@@ -60,20 +60,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
60
60
assert(mse <= requiredMSE, s " validateRegressor calculated MSE $mse but required $requiredMSE. " )
61
61
}
62
62
63
- test(" split and bin calculation for continuous features " ) {
63
+ test(" Binary classification with continuous features: split and bin calculation" ) {
64
64
val arr = DecisionTreeSuite .generateOrderedLabeledPointsWithLabel1()
65
65
assert(arr.length === 1000 )
66
66
val rdd = sc.parallelize(arr)
67
67
val strategy = new Strategy (Classification , Gini , 3 , 2 , 100 )
68
68
val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
69
+ assert(! metadata.isUnordered(featureIndex = 0 ))
69
70
val (splits, bins) = DecisionTree .findSplitsBins(rdd, metadata)
70
71
assert(splits.length === 2 )
71
72
assert(bins.length === 2 )
72
73
assert(splits(0 ).length === 99 )
73
74
assert(bins(0 ).length === 100 )
74
75
}
75
76
76
- test(" split and bin calculation for binary features " ) {
77
+ test(" Binary classification with binary features: split and bin calculation" ) {
77
78
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
78
79
assert(arr.length === 1000 )
79
80
val rdd = sc.parallelize(arr)
@@ -100,32 +101,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
100
101
assert(splits(0 )(0 ).threshold === Double .MinValue )
101
102
assert(splits(0 )(0 ).featureType === Categorical )
102
103
assert(splits(0 )(0 ).categories.length === 1 )
103
- // println(s"splits(0)(0).categories: ${splits(0)(0).categories}")
104
104
assert(splits(0 )(0 ).categories.contains(1.0 ))
105
105
106
- /*
107
- assert(splits(0)(1).feature === 0)
108
- assert(splits(0)(1).threshold === Double.MinValue)
109
- assert(splits(0)(1).featureType === Categorical)
110
- assert(splits(0)(1).categories.length === 2)
111
- assert(splits(0)(1).categories.contains(0.0))
112
- assert(splits(0)(1).categories.contains(1.0))
113
- */
114
-
115
106
assert(splits(1 )(0 ).feature === 1 )
116
107
assert(splits(1 )(0 ).threshold === Double .MinValue )
117
108
assert(splits(1 )(0 ).featureType === Categorical )
118
109
assert(splits(1 )(0 ).categories.length === 1 )
119
110
assert(splits(1 )(0 ).categories.contains(0.0 ))
120
111
121
- /*
122
- assert(splits(1)(1).feature === 1)
123
- assert(splits(1)(1).threshold === Double.MinValue)
124
- assert(splits(1)(1).featureType === Categorical)
125
- assert(splits(1)(1).categories.length === 2)
126
- assert(splits(1)(1).categories.contains(0.0))
127
- assert(splits(1)(1).categories.contains(1.0))
128
- */
129
112
// Check bins.
130
113
131
114
assert(bins(0 )(0 ).lowSplit.categories.length === 0 )
@@ -185,16 +168,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
185
168
assert(splits(0 )(1 ).categories.contains(0.0 ))
186
169
assert(splits(0 )(1 ).categories.contains(1.0 ))
187
170
188
- /*
189
- assert(splits(0)(2).feature === 0)
190
- assert(splits(0)(2).threshold === Double.MinValue)
191
- assert(splits(0)(2).featureType === Categorical)
192
- assert(splits(0)(2).categories.length === 3)
193
- assert(splits(0)(2).categories.contains(0.0))
194
- assert(splits(0)(2).categories.contains(1.0))
195
- assert(splits(0)(2).categories.contains(2.0))
196
- */
197
-
198
171
assert(splits(1 )(0 ).feature === 1 )
199
172
assert(splits(1 )(0 ).threshold === Double .MinValue )
200
173
assert(splits(1 )(0 ).featureType === Categorical )
@@ -208,16 +181,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
208
181
assert(splits(1 )(1 ).categories.contains(0.0 ))
209
182
assert(splits(1 )(1 ).categories.contains(1.0 ))
210
183
211
- /*
212
- assert(splits(1)(2).feature === 1)
213
- assert(splits(1)(2).threshold === Double.MinValue)
214
- assert(splits(1)(2).featureType === Categorical)
215
- assert(splits(1)(2).categories.length === 3)
216
- assert(splits(1)(2).categories.contains(0.0))
217
- assert(splits(1)(2).categories.contains(1.0))
218
- assert(splits(1)(2).categories.contains(2.0))
219
- */
220
-
221
184
// Check bins.
222
185
223
186
assert(bins(0 )(0 ).lowSplit.categories.length === 0 )
@@ -260,8 +223,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
260
223
assert(List (3.0 , 2.0 , 0.0 ).toSeq === l.toSeq)
261
224
}
262
225
263
- test(" split and bin calculations for unordered categorical variables with multiclass " +
264
- " classification " ) {
226
+ test(" Multiclass classification with unordered categorical features: " +
227
+ " split and bin calculations " ) {
265
228
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
266
229
assert(arr.length === 1000 )
267
230
val rdd = sc.parallelize(arr)
@@ -355,8 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
355
318
356
319
}
357
320
358
- test(" split and bin calculations for ordered categorical variables with multiclass " +
359
- " classification" ) {
321
+ test(" Multiclass classification with ordered categorical features: split and bin calculations" ) {
360
322
val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlassForOrderedFeatures()
361
323
assert(arr.length === 3000 )
362
324
val rdd = sc.parallelize(arr)
@@ -377,7 +339,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
377
339
assert(splits(0 ).length === 9 )
378
340
assert(bins(0 ).length === 10 )
379
341
380
- // 2^10 - 1 > 100, so categorical variables will be ordered
342
+ // 2^10 - 1 > 100, so categorical features will be ordered
381
343
382
344
assert(splits(0 )(0 ).feature === 0 )
383
345
assert(splits(0 )(0 ).threshold === Double .MinValue )
@@ -413,7 +375,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
413
375
}
414
376
415
377
416
- test(" classification stump with all ordered categorical variables " ) {
378
+ test(" Binary classification stump with all ordered categorical features " ) {
417
379
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
418
380
assert(arr.length === 1000 )
419
381
val rdd = sc.parallelize(arr)
@@ -450,7 +412,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
450
412
assert(stats.impurity > 0.2 )
451
413
}
452
414
453
- test(" regression stump with all categorical variables " ) {
415
+ test(" Regression stump with 3-ary categorical features " ) {
454
416
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
455
417
assert(arr.length === 1000 )
456
418
val rdd = sc.parallelize(arr)
@@ -482,7 +444,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
482
444
assert(stats.impurity > 0.2 )
483
445
}
484
446
485
- test(" regression stump with categorical variables of arity 2 " ) {
447
+ test(" Regression stump with binary categorical features " ) {
486
448
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
487
449
assert(arr.length === 1000 )
488
450
val rdd = sc.parallelize(arr)
@@ -502,7 +464,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
502
464
assert(model.depth === 1 )
503
465
}
504
466
505
- test(" stump with fixed label 0 for Gini" ) {
467
+ test(" Binary classification stump with fixed label 0 for Gini" ) {
506
468
val arr = DecisionTreeSuite .generateOrderedLabeledPointsWithLabel0()
507
469
assert(arr.length === 1000 )
508
470
val rdd = sc.parallelize(arr)
@@ -530,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
530
492
assert(bestSplits(0 )._2.rightImpurity === 0 )
531
493
}
532
494
533
- test(" stump with fixed label 1 for Gini" ) {
495
+ test(" Binary classification stump with fixed label 1 for Gini" ) {
534
496
val arr = DecisionTreeSuite .generateOrderedLabeledPointsWithLabel1()
535
497
assert(arr.length === 1000 )
536
498
val rdd = sc.parallelize(arr)
@@ -559,7 +521,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
559
521
assert(bestSplits(0 )._2.predict === 1 )
560
522
}
561
523
562
- test(" stump with fixed label 0 for Entropy" ) {
524
+ test(" Binary classification stump with fixed label 0 for Entropy" ) {
563
525
val arr = DecisionTreeSuite .generateOrderedLabeledPointsWithLabel0()
564
526
assert(arr.length === 1000 )
565
527
val rdd = sc.parallelize(arr)
@@ -588,7 +550,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
588
550
assert(bestSplits(0 )._2.predict === 0 )
589
551
}
590
552
591
- test(" stump with fixed label 1 for Entropy" ) {
553
+ test(" Binary classification stump with fixed label 1 for Entropy" ) {
592
554
val arr = DecisionTreeSuite .generateOrderedLabeledPointsWithLabel1()
593
555
assert(arr.length === 1000 )
594
556
val rdd = sc.parallelize(arr)
@@ -617,7 +579,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
617
579
assert(bestSplits(0 )._2.predict === 1 )
618
580
}
619
581
620
- test(" second level node building with/ without groups" ) {
582
+ test(" Second level node building with vs. without groups" ) {
621
583
val arr = DecisionTreeSuite .generateOrderedLabeledPoints()
622
584
assert(arr.length === 1000 )
623
585
val rdd = sc.parallelize(arr)
@@ -669,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
669
631
}
670
632
}
671
633
672
- test(" stump with categorical variables for multiclass classification " ) {
634
+ test(" Multiclass classification stump with 3-ary categorical features " ) {
673
635
val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlass()
674
636
val rdd = sc.parallelize(arr)
675
637
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
@@ -692,7 +654,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
692
654
assert(bestSplit.featureType === Categorical )
693
655
}
694
656
695
- test(" stump with 1 continuous variable for binary classification , to check off-by-1 error" ) {
657
+ test(" Binary classification stump with 1 continuous feature , to check off-by-1 error" ) {
696
658
val arr = new Array [LabeledPoint ](4 )
697
659
arr(0 ) = new LabeledPoint (0.0 , Vectors .dense(0.0 ))
698
660
arr(1 ) = new LabeledPoint (1.0 , Vectors .dense(1.0 ))
@@ -708,7 +670,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
708
670
assert(model.depth === 1 )
709
671
}
710
672
711
- test(" stump with 2 continuous variables for binary classification " ) {
673
+ test(" Binary classification stump with 2 continuous features " ) {
712
674
val arr = new Array [LabeledPoint ](4 )
713
675
arr(0 ) = new LabeledPoint (0.0 , Vectors .sparse(2 , Seq ((0 , 0.0 ))))
714
676
arr(1 ) = new LabeledPoint (1.0 , Vectors .sparse(2 , Seq ((1 , 1.0 ))))
@@ -726,7 +688,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
726
688
assert(model.topNode.split.get.feature === 1 )
727
689
}
728
690
729
- test(" stump with categorical variables for multiclass classification , with just enough bins" ) {
691
+ test(" Multiclass classification stump with categorical features , with just enough bins" ) {
730
692
val maxBins = math.pow(2 , 3 - 1 ).toInt // just enough bins to allow unordered features
731
693
val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlass()
732
694
val rdd = sc.parallelize(arr)
@@ -757,7 +719,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
757
719
assert(gain.rightImpurity === 0 )
758
720
}
759
721
760
- test(" stump with continuous variables for multiclass classification " ) {
722
+ test(" Multiclass classification stump with continuous features " ) {
761
723
val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
762
724
val rdd = sc.parallelize(arr)
763
725
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
@@ -783,7 +745,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
783
745
784
746
}
785
747
786
- test(" stump with continuous + categorical variables for multiclass classification " ) {
748
+ test(" Multiclass classification stump with continuous + categorical features " ) {
787
749
val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
788
750
val rdd = sc.parallelize(arr)
789
751
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
@@ -808,7 +770,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
808
770
assert(bestSplit.threshold < 2020 )
809
771
}
810
772
811
- test(" stump with categorical variables for ordered multiclass classification " ) {
773
+ test(" Multiclass classification stump with 10-ary ( ordered) categorical features " ) {
812
774
val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlassForOrderedFeatures()
813
775
val rdd = sc.parallelize(arr)
814
776
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
0 commit comments