@@ -35,8 +35,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
35
35
import org .apache .spark .rdd .RDD
36
36
import org .apache .spark .sql .{DataFrame , SQLContext }
37
37
38
- import NaiveBayes .ModelType .{Bernoulli , Multinomial }
39
-
40
38
41
39
/**
42
40
* Model for Naive Bayes Classifiers.
@@ -45,18 +43,17 @@ import NaiveBayes.ModelType.{Bernoulli, Multinomial}
45
43
* @param pi log of class priors, whose dimension is C, number of labels
46
44
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
47
45
* where D is number of features
48
- * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
49
- * Multinomial or Bernoulli
46
+ * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
50
47
*/
51
48
class NaiveBayesModel private [mllib] (
52
49
val labels : Array [Double ],
53
50
val pi : Array [Double ],
54
51
val theta : Array [Array [Double ]],
55
- val modelType : NaiveBayes . ModelType )
52
+ val modelType : String )
56
53
extends ClassificationModel with Serializable with Saveable {
57
54
58
55
private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
59
- this (labels, pi, theta, Multinomial )
56
+ this (labels, pi, theta, " Multinomial" )
60
57
61
58
/** A Java-friendly constructor that takes three Iterable parameters. */
62
59
private [mllib] def this (
@@ -72,8 +69,8 @@ class NaiveBayesModel private[mllib] (
72
69
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
73
70
// application of this condition (in predict function).
74
71
private val (brzNegTheta, brzNegThetaSum) = modelType match {
75
- case Multinomial => (None , None )
76
- case Bernoulli =>
72
+ case " Multinomial" => (None , None )
73
+ case " Bernoulli" =>
77
74
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
78
75
(Option (negTheta), Option (brzSum(negTheta, Axis ._1)))
79
76
case _ =>
@@ -91,9 +88,9 @@ class NaiveBayesModel private[mllib] (
91
88
92
89
override def predict (testData : Vector ): Double = {
93
90
modelType match {
94
- case Multinomial =>
91
+ case " Multinomial" =>
95
92
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
96
- case Bernoulli =>
93
+ case " Bernoulli" =>
97
94
labels (brzArgmax (brzPi +
98
95
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
99
96
case _ =>
@@ -103,7 +100,7 @@ class NaiveBayesModel private[mllib] (
103
100
}
104
101
105
102
override def save (sc : SparkContext , path : String ): Unit = {
106
- val data = NaiveBayesModel .SaveLoadV2_0 .Data (labels, pi, theta, modelType.toString )
103
+ val data = NaiveBayesModel .SaveLoadV2_0 .Data (labels, pi, theta, modelType)
107
104
NaiveBayesModel .SaveLoadV2_0 .save(sc, path, data)
108
105
}
109
106
@@ -155,7 +152,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
155
152
val labels = data.getAs[Seq [Double ]](0 ).toArray
156
153
val pi = data.getAs[Seq [Double ]](1 ).toArray
157
154
val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
158
- val modelType = NaiveBayes . ModelType .fromString( data.getString(3 ) )
155
+ val modelType = data.getString(3 )
159
156
new NaiveBayesModel (labels, pi, theta, modelType)
160
157
}
161
158
@@ -248,11 +245,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
248
245
249
246
class NaiveBayes private (
250
247
private var lambda : Double ,
251
- private var modelType : NaiveBayes . ModelType ) extends Serializable with Logging {
248
+ private var modelType : String ) extends Serializable with Logging {
252
249
253
- def this (lambda : Double ) = this (lambda, Multinomial )
250
+ def this (lambda : Double ) = this (lambda, " Multinomial" )
254
251
255
- def this () = this (1.0 , Multinomial )
252
+ def this () = this (1.0 , " Multinomial" )
256
253
257
254
/** Set the smoothing parameter. Default: 1.0. */
258
255
def setLambda (lambda : Double ): NaiveBayes = {
@@ -264,26 +261,21 @@ class NaiveBayes private (
264
261
def getLambda : Double = lambda
265
262
266
263
/**
267
- * Set the model type using a string (case-insensitive).
268
- * Supported options: "multinomial" and "bernoulli".
269
- * (default: multinomial)
270
- */
271
- def setModelType (modelType : String ): NaiveBayes = {
272
- setModelType(NaiveBayes .ModelType .fromString(modelType))
273
- }
274
-
275
- /**
276
- * Set the model type.
277
- * Supported options: [[NaiveBayes.ModelType.Bernoulli ]], [[NaiveBayes.ModelType.Multinomial ]]
264
+ * Set the model type using a string (case-sensitive).
265
+ * Supported options: "Multinomial" and "Bernoulli".
278
266
* (default: Multinomial)
279
267
*/
280
- def setModelType (modelType : NaiveBayes .ModelType ): NaiveBayes = {
281
- this .modelType = modelType
282
- this
268
+ def setModelType (modelType: String ): NaiveBayes = {
269
+ if (NaiveBayes .supportedModelTypes.contains(modelType)) {
270
+ this .modelType = modelType
271
+ this
272
+ } else {
273
+ throw new UnknownError (s " NaiveBayesModel does not support ModelType: $modelType" )
274
+ }
283
275
}
284
276
285
277
/** Get the model type. */
286
- def getModelType : NaiveBayes . ModelType = this .modelType
278
+ def getModelType : String = this .modelType
287
279
288
280
/**
289
281
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -336,8 +328,8 @@ class NaiveBayes private (
336
328
labels(i) = label
337
329
pi(i) = math.log(n + lambda) - piLogDenom
338
330
val thetaLogDenom = modelType match {
339
- case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340
- case Bernoulli => math.log(n + 2.0 * lambda)
331
+ case " Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
332
+ case " Bernoulli" => math.log(n + 2.0 * lambda)
341
333
case _ =>
342
334
// This should never happen.
343
335
throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
@@ -358,6 +350,10 @@ class NaiveBayes private (
358
350
* Top-level methods for calling naive Bayes.
359
351
*/
360
352
object NaiveBayes {
353
+
354
+ /* Set of modelTypes that NaiveBayes supports */
355
+ private [mllib] val supportedModelTypes = Set (" Multinomial" , " Bernoulli" )
356
+
361
357
/**
362
358
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
363
359
*
@@ -386,7 +382,7 @@ object NaiveBayes {
386
382
* @param lambda The smoothing parameter
387
383
*/
388
384
def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
389
- new NaiveBayes (lambda, NaiveBayes . ModelType . Multinomial ).run(input)
385
+ new NaiveBayes (lambda, " Multinomial" ).run(input)
390
386
}
391
387
392
388
/**
@@ -408,42 +404,11 @@ object NaiveBayes {
408
404
* multinomial or bernoulli
409
405
*/
410
406
def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
411
- new NaiveBayes (lambda, ModelType .fromString(modelType)).run(input)
412
- }
413
-
414
- /** Provides static methods for using ModelType. */
415
- sealed abstract class ModelType extends Serializable
416
-
417
- object ModelType extends Serializable {
418
-
419
- /**
420
- * Get the model type from a string.
421
- * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
422
- */
423
- def fromString (modelType : String ): ModelType = modelType.toLowerCase match {
424
- case " multinomial" => Multinomial
425
- case " bernoulli" => Bernoulli
426
- case _ =>
427
- throw new IllegalArgumentException (
428
- s " NaiveBayes.ModelType.fromString did not recognize string: $modelType" )
429
- }
430
-
431
- final val Multinomial : ModelType = {
432
- case object Multinomial extends ModelType with Serializable {
433
- override def toString : String = " multinomial"
434
- }
435
- Multinomial
436
- }
437
-
438
- final val Bernoulli : ModelType = {
439
- case object Bernoulli extends ModelType with Serializable {
440
- override def toString : String = " bernoulli"
441
- }
442
- Bernoulli
407
+ if (supportedModelTypes.contains(modelType)) {
408
+ new NaiveBayes (lambda, modelType).run(input)
409
+ } else {
410
+ throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
443
411
}
444
412
}
445
413
446
- /** Java-friendly accessor for supported ModelType options */
447
- final val modelTypes = ModelType
448
-
449
414
}
0 commit comments