Skip to content

Commit 01baad7

Browse files
committed
made fixes from code review
1 parent fb0a5c7 commit 01baad7

File tree

3 files changed

+14
-26
lines changed

3 files changed

+14
-26
lines changed

docs/mllib-naive-bayes.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ and use it for prediction.
1515
MLlib supports [multinomial naive
1616
Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
1717
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
18-
Which are typically used for [document classification]
18+
These models are typically used for [document classification]
1919
(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
2020
Within that context, each observation is a document and each
2121
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
2222
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
23-
Feature values must be nonnegative.The model type is selected with on optional parameter
23+
Feature values must be nonnegative. The model type is selected with an optional parameter
2424
"Multinomial" or "Bernoulli" with "Multinomial" as the default.
2525
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
2626
setting the parameter $\lambda$ (defaults to $1.0$). For document classification, the input feature

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ class NaiveBayesModel private[mllib] (
4949
val modelType: String)
5050
extends ClassificationModel with Serializable with Saveable {
5151

52-
def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
52+
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
5353
this(labels, pi, theta, NaiveBayes.Multinomial.toString)
5454

5555
private val brzPi = new BDV[Double](pi)
5656
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
5757

58-
// Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
59-
// this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
60-
// of this condition in predict function
58+
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
59+
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
60+
// application of this condition (in predict function).
6161
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match {
6262
case NaiveBayes.Multinomial => (None, None)
6363
case NaiveBayes.Bernoulli =>
@@ -186,8 +186,6 @@ class NaiveBayes private (
186186
private var lambda: Double,
187187
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
188188

189-
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
190-
191189
def this() = this(1.0, NaiveBayes.Multinomial)
192190

193191
/** Set the smoothing parameter. Default: 1.0. */
@@ -202,6 +200,7 @@ class NaiveBayes private (
202200
this
203201
}
204202

203+
def getModelType(): NaiveBayes.ModelType = this.modelType
205204

206205
/**
207206
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -301,10 +300,9 @@ object NaiveBayes {
301300
* @param lambda The smoothing parameter
302301
*/
303302
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
304-
new NaiveBayes(lambda).run(input)
303+
new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input)
305304
}
306305

307-
308306
/**
309307
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
310308
*
@@ -327,11 +325,7 @@ object NaiveBayes {
327325
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
328326
}
329327

330-
331-
/**
332-
* Model types supported in Naive Bayes:
333-
* multinomial and Bernoulli currently supported
334-
*/
328+
/** Provides static methods for using ModelType. */
335329
sealed abstract class ModelType
336330

337331
object MODELTYPE {
@@ -348,10 +342,12 @@ object NaiveBayes {
348342

349343
final val ModelType = MODELTYPE
350344

345+
/** Constant for specifying ModelType parameter: multinomial model */
351346
final val Multinomial: ModelType = new ModelType {
352347
override def toString: String = ModelType.MULTINOMIAL_STRING
353348
}
354349

350+
/** Constant for specifying ModelType parameter: Bernoulli model */
355351
final val Bernoulli: ModelType = new ModelType {
356352
override def toString: String = ModelType.BERNOULLI_STRING
357353
}

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ object NaiveBayesSuite {
5858
for (i <- 0 until nPoints) yield {
5959
val y = calcLabel(rnd.nextDouble(), _pi)
6060
val xi = dataModel match {
61-
case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) {j =>
61+
case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) { j =>
6262
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
6363
}
6464
case NaiveBayes.Multinomial =>
@@ -118,23 +118,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
118118
).map(_.map(math.log))
119119

120120
val testData = NaiveBayesSuite.generateNaiveBayesInput(
121-
pi,
122-
theta,
123-
nPoints,
124-
42,
125-
NaiveBayes.Multinomial)
121+
pi, theta, nPoints, 42, NaiveBayes.Multinomial)
126122
val testRDD = sc.parallelize(testData, 2)
127123
testRDD.cache()
128124

129125
val model = NaiveBayes.train(testRDD, 1.0, "multinomial")
130126
validateModelFit(pi, theta, model)
131127

132128
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
133-
pi,
134-
theta,
135-
nPoints,
136-
17,
137-
NaiveBayes.Multinomial)
129+
pi, theta, nPoints, 17, NaiveBayes.Multinomial)
138130
val validationRDD = sc.parallelize(validationData, 2)
139131

140132
// Test prediction on RDD.

0 commit comments

Comments
 (0)