@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
31
31
import org .apache .spark .mllib .tree .model .{DecisionTreeModel , Node }
32
32
import org .apache .spark .mllib .util .LocalSparkContext
33
33
34
-
35
34
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
36
35
37
36
def validateClassifier (
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
353
352
assert(splits(0 ).length === 99 )
354
353
assert(bins.length === 2 )
355
354
assert(bins(0 ).length === 100 )
356
- assert(splits(0 ).length === 99 )
357
- assert(bins(0 ).length === 100 )
358
355
359
356
val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
360
357
val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (8 ), metadata, 0 ,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
381
378
assert(splits(0 ).length === 99 )
382
379
assert(bins.length === 2 )
383
380
assert(bins(0 ).length === 100 )
384
- assert(splits(0 ).length === 99 )
385
- assert(bins(0 ).length === 100 )
386
381
387
382
val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
388
383
val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
410
405
assert(splits(0 ).length === 99 )
411
406
assert(bins.length === 2 )
412
407
assert(bins(0 ).length === 100 )
413
- assert(splits(0 ).length === 99 )
414
- assert(bins(0 ).length === 100 )
415
408
416
409
val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
417
410
val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
439
432
assert(splits(0 ).length === 99 )
440
433
assert(bins.length === 2 )
441
434
assert(bins(0 ).length === 100 )
442
- assert(splits(0 ).length === 99 )
443
- assert(bins(0 ).length === 100 )
444
435
445
436
val treeInput = TreePoint .convertToTreeRDD(rdd, bins, metadata)
446
437
val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (2 ), metadata, 0 ,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
464
455
assert(splits(0 ).length === 99 )
465
456
assert(bins.length === 2 )
466
457
assert(bins(0 ).length === 100 )
467
- assert(splits(0 ).length === 99 )
468
- assert(bins(0 ).length === 100 )
469
458
470
459
// Train a 1-node model
471
460
val strategyOneNode = new Strategy (Classification , Entropy , 1 , 2 , 100 )
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
600
589
val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
601
590
val rdd = sc.parallelize(arr)
602
591
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
603
- numClassesForClassification = 3 )
592
+ numClassesForClassification = 3 , maxBins = 100 )
604
593
assert(strategy.isMulticlassClassification)
605
594
val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
606
595
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
626
615
val arr = DecisionTreeSuite .generateContinuousDataPointsForMulticlass()
627
616
val rdd = sc.parallelize(arr)
628
617
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
629
- numClassesForClassification = 3 , categoricalFeaturesInfo = Map (0 -> 3 ))
618
+ numClassesForClassification = 3 , maxBins = 100 , categoricalFeaturesInfo = Map (0 -> 3 ))
630
619
assert(strategy.isMulticlassClassification)
631
620
val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
632
621
assert(metadata.isUnordered(featureIndex = 0 ))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
652
641
val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlassForOrderedFeatures()
653
642
val rdd = sc.parallelize(arr)
654
643
val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 4 ,
655
- numClassesForClassification = 3 , categoricalFeaturesInfo = Map (0 -> 10 , 1 -> 10 ))
644
+ numClassesForClassification = 3 , maxBins = 100 ,
645
+ categoricalFeaturesInfo = Map (0 -> 10 , 1 -> 10 ))
656
646
assert(strategy.isMulticlassClassification)
657
647
val metadata = DecisionTreeMetadata .buildMetadata(rdd, strategy)
658
648
assert(! metadata.isUnordered(featureIndex = 0 ))
0 commit comments