@@ -25,7 +25,6 @@ import scala.util.Random
2525import org .scalatest .FunSuite
2626
2727import org .apache .spark .SparkException
28- import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
2928import org .apache .spark .mllib .linalg .Vectors
3029import org .apache .spark .mllib .regression .LabeledPoint
3130import org .apache .spark .mllib .util .{LocalClusterSparkContext , MLlibTestSparkContext }
@@ -49,7 +48,7 @@ object NaiveBayesSuite {
4948 theta : Array [Array [Double ]], // CXD
5049 nPoints : Int ,
5150 seed : Int ,
52- dataModel : NaiveBayesModels = NaiveBayesModels .Multinomial ,
51+ dataModel : NaiveBayes . ModelType = NaiveBayes .Multinomial ,
5352 sample : Int = 10 ): Seq [LabeledPoint ] = {
5453 val D = theta(0 ).length
5554 val rnd = new Random (seed)
@@ -60,10 +59,10 @@ object NaiveBayesSuite {
6059 for (i <- 0 until nPoints) yield {
6160 val y = calcLabel(rnd.nextDouble(), _pi)
6261 val xi = dataModel match {
63- case NaiveBayesModels .Bernoulli => Array .tabulate[Double ] (D ) {j =>
62+ case NaiveBayes .Bernoulli => Array .tabulate[Double ] (D ) {j =>
6463 if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
6564 }
66- case NaiveBayesModels .Multinomial =>
65+ case NaiveBayes .Multinomial =>
6766 val mult = Multinomial (BDV (_theta(y)))
6867 val emptyMap = (0 until D ).map(x => (x, 0.0 )).toMap
6968 val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -78,7 +77,7 @@ object NaiveBayesSuite {
7877
7978 /** Binary labels, 3 features */
8079 private val binaryModel = new NaiveBayesModel (labels = Array (0.0 , 1.0 ), pi = Array (0.2 , 0.8 ),
81- theta = Array (Array (0.1 , 0.3 , 0.6 ), Array (0.2 , 0.4 , 0.4 )), NaiveBayesModels .Bernoulli )
80+ theta = Array (Array (0.1 , 0.3 , 0.6 ), Array (0.2 , 0.4 , 0.4 )), NaiveBayes .Bernoulli )
8281}
8382
8483class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -121,7 +120,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
121120 ).map(_.map(math.log))
122121
123122 val testData = NaiveBayesSuite .generateNaiveBayesInput(
124- pi, theta, nPoints, 42 , NaiveBayesModels .Multinomial )
123+ pi, theta, nPoints, 42 , NaiveBayes .Multinomial )
125124 val testRDD = sc.parallelize(testData, 2 )
126125 testRDD.cache()
127126
@@ -133,7 +132,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
133132 theta,
134133 nPoints,
135134 17 ,
136- NaiveBayesModels .Multinomial )
135+ NaiveBayes .Multinomial )
137136 val validationRDD = sc.parallelize(validationData, 2 )
138137
139138 // Test prediction on RDD.
@@ -158,7 +157,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
158157 theta,
159158 nPoints,
160159 45 ,
161- NaiveBayesModels .Bernoulli )
160+ NaiveBayes .Bernoulli )
162161 val testRDD = sc.parallelize(testData, 2 )
163162 testRDD.cache()
164163
@@ -170,7 +169,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
170169 theta,
171170 nPoints,
172171 20 ,
173- NaiveBayesModels .Bernoulli )
172+ NaiveBayes .Bernoulli )
174173 val validationRDD = sc.parallelize(validationData, 2 )
175174
176175 // Test prediction on RDD.
0 commit comments