@@ -49,15 +49,15 @@ class NaiveBayesModel private[mllib] (
49
49
val modelType : String )
50
50
extends ClassificationModel with Serializable with Saveable {
51
51
52
- def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
52
+ private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
53
53
this (labels, pi, theta, NaiveBayes .Multinomial .toString)
54
54
55
55
private val brzPi = new BDV [Double ](pi)
56
56
private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
57
57
58
- // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
59
- // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
60
- // of this condition in predict function
58
+ // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
59
+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
60
+ // application of this condition ( in predict function).
61
61
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes .ModelType .fromString(modelType) match {
62
62
case NaiveBayes .Multinomial => (None , None )
63
63
case NaiveBayes .Bernoulli =>
@@ -186,8 +186,6 @@ class NaiveBayes private (
186
186
private var lambda : Double ,
187
187
private var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
188
188
189
- def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
190
-
191
189
def this () = this (1.0 , NaiveBayes .Multinomial )
192
190
193
191
/** Set the smoothing parameter. Default: 1.0. */
@@ -202,6 +200,7 @@ class NaiveBayes private (
202
200
this
203
201
}
204
202
203
+ def getModelType (): NaiveBayes .ModelType = this .modelType
205
204
206
205
/**
207
206
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -301,10 +300,9 @@ object NaiveBayes {
301
300
* @param lambda The smoothing parameter
302
301
*/
303
302
def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
304
- new NaiveBayes (lambda).run(input)
303
+ new NaiveBayes (lambda, NaiveBayes . Multinomial ).run(input)
305
304
}
306
305
307
-
308
306
/**
309
307
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
310
308
*
@@ -327,11 +325,7 @@ object NaiveBayes {
327
325
new NaiveBayes (lambda, MODELTYPE .fromString(modelType)).run(input)
328
326
}
329
327
330
-
331
- /**
332
- * Model types supported in Naive Bayes:
333
- * multinomial and Bernoulli currently supported
334
- */
328
+ /** Provides static methods for using ModelType. */
335
329
sealed abstract class ModelType
336
330
337
331
object MODELTYPE {
@@ -348,10 +342,12 @@ object NaiveBayes {
348
342
349
343
final val ModelType = MODELTYPE
350
344
345
+ /** Constant for specifying ModelType parameter: multinomial model */
351
346
final val Multinomial : ModelType = new ModelType {
352
347
override def toString : String = ModelType .MULTINOMIAL_STRING
353
348
}
354
349
350
+ /** Constant for specifying ModelType parameter: bernoulli model */
355
351
final val Bernoulli : ModelType = new ModelType {
356
352
override def toString : String = ModelType .BERNOULLI_STRING
357
353
}
0 commit comments