Skip to content

Commit 4a3676d

Browse files
committed
Updated changes re-comments. Got rid of verbose populateMatrix method. Public api now has string instead of enumeration. Docs are updated."
1 parent ce73c63 commit 4a3676d

File tree

3 files changed

+20
-30
lines changed

3 files changed

+20
-30
lines changed

docs/mllib-naive-bayes.md

+10-7
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation
1313
and use it for prediction.
1414

1515
MLlib supports [multinomial naive
16-
Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes),
17-
which is typically used for [document
18-
classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
16+
Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
17+
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
18+
These models are typically used for
19+
[document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
1920
Within that context, each observation is a document and each
20-
feature represents a term whose value is the frequency of the term.
21-
Feature values must be nonnegative to represent term frequencies.
21+
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
22+
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
23+
Feature values must be nonnegative. The model type is selected with an optional parameter
24+
"Multinomial" or "Bernoulli" with "Multinomial" as the default.
2225
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
2326
setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
2427
vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
@@ -32,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach
3235
[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
3336
multinomial and Bernoulli naive Bayes. It takes an RDD of
3437
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
35-
smoothing parameter `lambda` as input, and output a
38+
smoothing parameter `lambda`, plus an optional model type parameter (default is "Multinomial"), as input, and outputs a
3639
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
3740
can be used for evaluation and prediction.
3841

@@ -51,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
5154
val training = splits(0)
5255
val test = splits(1)
5356

54-
val model = NaiveBayes.train(training, lambda = 1.0)
57+
val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
5558

5659
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
5760
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

+7-21
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
package org.apache.spark.mllib.classification
1919

2020
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
21-
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
21+
import breeze.numerics.{exp => brzExp, log => brzLog}
2222

2323
import org.apache.spark.{SparkException, Logging}
2424
import org.apache.spark.SparkContext._
2525
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
2626
import org.apache.spark.mllib.regression.LabeledPoint
27+
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
2728
import org.apache.spark.rdd.RDD
2829

2930

@@ -52,29 +53,14 @@ class NaiveBayesModel private[mllib] (
5253
val theta: Array[Array[Double]],
5354
val model: NaiveBayesModels) extends ClassificationModel with Serializable {
5455

55-
def populateMatrix(arrayIn: Array[Array[Double]],
56-
matrixIn: BDM[Double],
57-
transformation: (Double) => Double = (x) => x) = {
58-
var i = 0
59-
while (i < arrayIn.length) {
60-
var j = 0
61-
while (j < arrayIn(i).length) {
62-
matrixIn(i, j) = transformation(theta(i)(j))
63-
j += 1
64-
}
65-
i += 1
66-
}
67-
}
68-
6956
private val brzPi = new BDV[Double](pi)
70-
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
71-
populateMatrix(theta, brzTheta)
57+
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
7258

7359
private val brzNegTheta: Option[BDM[Double]] = model match {
7460
case NaiveBayesModels.Multinomial => None
7561
case NaiveBayesModels.Bernoulli =>
76-
val negTheta = new BDM[Double](theta.length, theta(0).length)
77-
populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x)))
62+
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0)
63+
// Element-wise log(1.0 - exp(theta)), i.e. (x) => math.log(1.0 - math.exp(x)) applied to every entry.
7864
Option(negTheta)
7965
}
8066

@@ -244,7 +230,7 @@ object NaiveBayes {
244230
* @param model The type of NB model to fit, given as the name of a NaiveBayesModels value:
245231
* "Multinomial" or "Bernoulli"
246232
*/
247-
def train(input: RDD[LabeledPoint], lambda: Double, model: NaiveBayesModels): NaiveBayesModel = {
248-
new NaiveBayes(lambda, model).run(input)
233+
def train(input: RDD[LabeledPoint], lambda: Double, model: String): NaiveBayesModel = {
234+
new NaiveBayes(lambda, NaiveBayesModels.withName(model)).run(input)
249235
}
250236
}

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
117117
val testRDD = sc.parallelize(testData, 2)
118118
testRDD.cache()
119119

120-
val model = NaiveBayes.train(testRDD, 1.0, NaiveBayesModels.Multinomial)
120+
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
121121
validateModelFit(pi, theta, model)
122122

123123
val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17, NaiveBayesModels.Multinomial)
@@ -140,11 +140,12 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
140140
Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
141141
).map(_.map(math.log))
142142

143+
143144
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 45, NaiveBayesModels.Bernoulli)
144145
val testRDD = sc.parallelize(testData, 2)
145146
testRDD.cache()
146147

147-
val model = NaiveBayes.train(testRDD, 1.0, NaiveBayesModels.Bernoulli) ///!!! this gives same result on both models check the math
148+
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") ///!!! this gives same result on both models check the math
148149
validateModelFit(pi, theta, model)
149150

150151
val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 20, NaiveBayesModels.Bernoulli)

0 commit comments

Comments
 (0)