1717
1818package org .apache .spark .mllib .classification
1919
20- import breeze .linalg .{DenseMatrix => BDM , DenseVector => BDV , argmax => brzArgmax , sum => brzSum }
20+ import breeze .linalg .{DenseMatrix => BDM , DenseVector => BDV , argmax => brzArgmax , sum => brzSum , Axis }
21+ import org .apache .spark .mllib .classification .NaiveBayesModels .NaiveBayesModels
2122
2223import org .apache .spark .{SparkException , Logging }
2324import org .apache .spark .SparkContext ._
2425import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector }
2526import org .apache .spark .mllib .regression .LabeledPoint
2627import org .apache .spark .rdd .RDD
2728
29+
30+ /**
31+ *
32+ */
33+ object NaiveBayesModels extends Enumeration {
34+ type NaiveBayesModels = Value
35+ val Multinomial, Bernoulli = Value
36+ }
37+
2838/**
2939 * Model for Naive Bayes Classifiers.
3040 *
3141 * @param labels list of labels
3242 * @param pi log of class priors, whose dimension is C, number of labels
3343 * @param theta log of class conditional probabilities, whose dimension is C-by-D,
3444 * where D is number of features
45+ * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be
46+ * Multinomial or Bernoulli
3547 */
48+
3649class NaiveBayesModel private [mllib] (
3750 val labels : Array [Double ],
3851 val pi : Array [Double ],
39- val theta : Array [Array [Double ]]) extends ClassificationModel with Serializable {
40-
41- private val brzPi = new BDV [Double ](pi)
42- private val brzTheta = new BDM [Double ](theta.length, theta(0 ).length)
52+ val theta : Array [Array [Double ]],
53+ val model : NaiveBayesModels ) extends ClassificationModel with Serializable {
4354
44- {
45- // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
55+ def populateMatrix (arrayIn : Array [Array [Double ]],
56+ matrixIn : BDM [Double ],
57+ transformation : (Double ) => Double = (x) => x) = {
4658 var i = 0
47- while (i < theta .length) {
59+ while (i < arrayIn .length) {
4860 var j = 0
49- while (j < theta (i).length) {
50- brzTheta (i, j) = theta(i)(j)
61+ while (j < arrayIn (i).length) {
62+ matrixIn (i, j) = transformation( theta(i)(j) )
5163 j += 1
5264 }
5365 i += 1
5466 }
5567 }
5668
69+ private val brzPi = new BDV [Double ](pi)
70+ private val brzTheta = new BDM [Double ](theta.length, theta(0 ).length)
71+ populateMatrix(theta, brzTheta)
72+
73+ private val brzNegTheta : Option [BDM [Double ]] = model match {
74+ case NaiveBayesModels .Multinomial => None
75+ case NaiveBayesModels .Bernoulli =>
76+ val negTheta = new BDM [Double ](theta.length, theta(0 ).length)
77+ populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x)))
78+ Option (negTheta)
79+ }
80+
5781 override def predict (testData : RDD [Vector ]): RDD [Double ] = {
5882 val bcModel = testData.context.broadcast(this )
5983 testData.mapPartitions { iter =>
@@ -63,7 +87,14 @@ class NaiveBayesModel private[mllib] (
6387 }
6488
6589 override def predict (testData : Vector ): Double = {
66- labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
90+ model match {
91+ case NaiveBayesModels .Multinomial =>
92+ labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
93+ case NaiveBayesModels .Bernoulli =>
94+ labels (brzArgmax (brzPi +
95+ (brzTheta - brzNegTheta.get) * testData.toBreeze +
96+ brzSum(brzNegTheta.get, Axis ._1)))
97+ }
6798 }
6899}
69100
@@ -75,16 +106,26 @@ class NaiveBayesModel private[mllib] (
75106 * document classification. By making every vector a 0-1 vector, it can also be used as
76107 * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
77108 */
78- class NaiveBayes private (private var lambda : Double ) extends Serializable with Logging {
109+ class NaiveBayes private (private var lambda : Double ,
110+ var model : NaiveBayesModels ) extends Serializable with Logging {
79111
80- def this () = this (1.0 )
112+ def this (lambda : Double ) = this (lambda, NaiveBayesModels .Multinomial )
113+
114+ def this () = this (1.0 , NaiveBayesModels .Multinomial )
81115
82116 /** Set the smoothing parameter. Default: 1.0. */
83117 def setLambda (lambda : Double ): NaiveBayes = {
84118 this .lambda = lambda
85119 this
86120 }
87121
122+ /** Set the model type. Default: Multinomial. */
123+ def setModelType (model : NaiveBayesModels ): NaiveBayes = {
124+ this .model = model
125+ this
126+ }
127+
128+
88129 /**
89130 * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
90131 *
@@ -118,21 +159,27 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
118159 mergeCombiners = (c1 : (Long , BDV [Double ]), c2 : (Long , BDV [Double ])) =>
119160 (c1._1 + c2._1, c1._2 += c2._2)
120161 ).collect()
162+
121163 val numLabels = aggregated.length
122164 var numDocuments = 0L
123165 aggregated.foreach { case (_, (n, _)) =>
124166 numDocuments += n
125167 }
126168 val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
169+
127170 val labels = new Array [Double ](numLabels)
128171 val pi = new Array [Double ](numLabels)
129172 val theta = Array .fill(numLabels)(new Array [Double ](numFeatures))
173+
130174 val piLogDenom = math.log(numDocuments + numLabels * lambda)
131175 var i = 0
132176 aggregated.foreach { case (label, (n, sumTermFreqs)) =>
133177 labels(i) = label
134- val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
135178 pi(i) = math.log(n + lambda) - piLogDenom
179+ val thetaLogDenom = model match {
180+ case NaiveBayesModels .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
181+ case NaiveBayesModels .Bernoulli => math.log(n + 2.0 * lambda)
182+ }
136183 var j = 0
137184 while (j < numFeatures) {
138185 theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
@@ -141,7 +188,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
141188 i += 1
142189 }
143190
144- new NaiveBayesModel (labels, pi, theta)
191+ new NaiveBayesModel (labels, pi, theta, model )
145192 }
146193}
147194
@@ -154,8 +201,7 @@ object NaiveBayes {
154201 *
155202 * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
156203 * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
157- * document classification. By making every vector a 0-1 vector, it can also be used as
158- * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
204+ * document classification.
159205 *
160206 * This version of the method uses a default smoothing parameter of 1.0.
161207 *
@@ -171,8 +217,7 @@ object NaiveBayes {
171217 *
172218 * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
173219 * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
174- * document classification. By making every vector a 0-1 vector, it can also be used as
175- * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
220+ * document classification.
176221 *
177222 * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
178223 * vector or a count vector.
@@ -181,4 +226,25 @@ object NaiveBayes {
181226 def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
182227 new NaiveBayes (lambda).run(input)
183228 }
229+
230+
231+ /**
232+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
233+ *
234+ * This is by default the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle
235+ * all kinds of discrete data. For example, by converting documents into TF-IDF vectors,
236+ * it can be used for document classification. By making every vector a 0-1 vector and
237+ * setting the model type to NaiveBayesModels.Bernoulli, it fits and predicts as
238+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
239+ *
240+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
241+ * vector or a count vector.
242+ * @param lambda The smoothing parameter
243+ *
244+ * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be
245+ * Multinomial or Bernoulli
246+ */
247+ def train (input : RDD [LabeledPoint ], lambda : Double , model : NaiveBayesModels ): NaiveBayesModel = {
248+ new NaiveBayes (lambda, model).run(input)
249+ }
184250}
0 commit comments