@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
35
35
import org .apache .spark .rdd .RDD
36
36
import org .apache .spark .sql .{DataFrame , SQLContext }
37
37
38
+ import NaiveBayes .ModelType .{Bernoulli , Multinomial }
39
+
38
40
39
41
/**
40
42
* Model for Naive Bayes Classifiers.
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
54
56
extends ClassificationModel with Serializable with Saveable {
55
57
56
58
/**
 * Three-argument constructor: builds the model from labels, class priors and class
 * conditionals, using the Multinomial model type.
 */
private[mllib] def this(
    labels: Array[Double],
    pi: Array[Double],
    theta: Array[Array[Double]]) = {
  this(labels, pi, theta, Multinomial)
}
58
60
59
61
/** A Java-friendly constructor that takes three Iterable parameters. */
60
62
private [mllib] def this (
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
70
72
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
71
73
// application of this condition (in predict function).
72
74
private val (brzNegTheta, brzNegThetaSum) = modelType match {
  case Multinomial =>
    // The multinomial prediction path does not use these terms.
    (None, None)
  case Bernoulli =>
    // Work on a copy so brzTheta itself is left untouched, then apply the
    // element-wise steps in place: exp(x) -> -exp(x) -> 1.0 - exp(x) -> log(1.0 - exp(x)).
    val oneMinusExpTheta = brzExp(brzTheta.copy)
    oneMinusExpTheta :*= (-1.0)
    oneMinusExpTheta :+= 1.0
    val logOneMinusExpTheta = brzLog(oneMinusExpTheta)
    (Option(logOneMinusExpTheta), Option(brzSum(logOneMinusExpTheta, Axis._1)))
  case _ =>
    // This should never happen.
    throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
78
83
79
84
override def predict (testData : RDD [Vector ]): RDD [Double ] = {
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (
86
91
87
92
/**
 * Predicts the label of a single feature vector.
 *
 * A per-class score vector is computed according to the model type, and the label
 * at the position of the largest score is returned.
 */
override def predict(testData: Vector): Double = {
  val classScores = modelType match {
    case Multinomial =>
      brzPi + brzTheta * testData.toBreeze
    case Bernoulli =>
      // Uses the precomputed log(1.0 - exp(theta)) terms; they are only defined for Bernoulli.
      brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get
    case _ =>
      // This should never happen.
      throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
  }
  labels(brzArgmax(classScores))
}
96
104
97
105
/** Saves this model under `path` using the 2.0 save/load format. */
override def save(sc: SparkContext, path: String): Unit = {
  import NaiveBayesModel.SaveLoadV2_0
  // The model type is stored as its string form alongside labels, pi and theta.
  SaveLoadV2_0.save(sc, path, SaveLoadV2_0.Data(labels, pi, theta, modelType.toString))
}
101
109
102
/** Format version reported for this model; read from the SaveLoad object that `save` writes with. */
override protected def formatVersion: String = NaiveBayesModel.SaveLoadV2_0.thisFormatVersion
103
111
}
104
112
105
113
object NaiveBayesModel extends Loader [NaiveBayesModel ] {
106
114
107
115
import org .apache .spark .mllib .util .Loader ._
108
116
109
- private object SaveLoadV1_0 {
117
+ private [mllib] object SaveLoadV2_0 {
110
118
111
- def thisFormatVersion : String = " 1 .0"
119
+ def thisFormatVersion : String = " 2 .0"
112
120
113
121
/** Hard-code class name string in case it changes in the future */
114
122
def thisClassName : String = " org.apache.spark.mllib.classification.NaiveBayesModel"
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
127
135
// Create JSON metadata.
128
136
val metadata = compact(render(
129
137
(" class" -> thisClassName) ~ (" version" -> thisFormatVersion) ~
130
- (" numFeatures" -> data.theta(0 ).length) ~ (" numClasses" -> data.pi.length) ~
131
- (" modelType" -> data.modelType)))
138
+ (" numFeatures" -> data.theta(0 ).length) ~ (" numClasses" -> data.pi.length)))
132
139
sc.parallelize(Seq (metadata), 1 ).saveAsTextFile(metadataPath(path))
133
140
134
141
// Create Parquet data.
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
151
158
val modelType = NaiveBayes .ModelType .fromString(data.getString(3 ))
152
159
new NaiveBayesModel (labels, pi, theta, modelType)
153
160
}
161
+
154
162
}
155
163
156
- override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
157
- def getModelType (metadata : JValue ): NaiveBayes .ModelType = {
158
- implicit val formats = DefaultFormats
159
- NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ])
164
/** Save/load support for the 1.0 format, whose data holds labels, pi and theta only. */
private[mllib] object SaveLoadV1_0 {

  def thisFormatVersion: String = "1.0"

  /** Hard-code class name string in case it changes in the future */
  def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

  /** Model data for model import/export */
  case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

  /** Writes the JSON metadata and the Parquet model data under `path`. */
  def save(sc: SparkContext, path: String, data: Data): Unit = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // JSON metadata: class name, format version and the model dimensions.
    val metadataJson = compact(render(
      ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
        ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath(path))

    // Parquet data: a single row holding the whole model.
    val modelFrame: DataFrame = sc.parallelize(Seq(data), 1).toDF()
    modelFrame.saveAsParquetFile(dataPath(path))
  }

  /** Reads the Parquet model data under `path` and rebuilds the model from it. */
  def load(sc: SparkContext, path: String): NaiveBayesModel = {
    val sqlContext = new SQLContext(sc)
    val modelFrame = sqlContext.parquetFile(dataPath(path))
    // Check schema explicitly since erasure makes it hard to use match-case for checking.
    checkSchema[Data](modelFrame.schema)
    val rows = modelFrame.select("labels", "pi", "theta").take(1)
    assert(rows.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
    val row = rows(0)
    val labels = row.getAs[Seq[Double]](0).toArray
    val pi = row.getAs[Seq[Double]](1).toArray
    val theta = row.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
    // The three-argument constructor is used because this format stores no model type.
    new NaiveBayesModel(labels, pi, theta)
  }
}
207
+
208
/**
 * Loads a NaiveBayesModel saved under `path`.
 *
 * The stored metadata selects the loader: format versions "1.0" and "2.0" are accepted
 * when the stored class name matches the corresponding SaveLoad object. After loading,
 * the sizes of pi and theta are checked against the numFeatures / numClasses metadata.
 *
 * Throws an Exception when the (class name, version) pair is not one of the supported
 * combinations.
 */
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
  val (loadedClassName, version, metadata) = loadMetadata(sc, path)
  val classNameV1_0 = SaveLoadV1_0.thisClassName
  val classNameV2_0 = SaveLoadV2_0.thisClassName
  // Backticks make these stable-identifier patterns, i.e. an equality test on the class name.
  val (model, numFeatures, numClasses) = (loadedClassName, version) match {
    case (`classNameV1_0`, "1.0") =>
      val (features, classes) = ClassificationModel.getNumFeaturesClasses(metadata)
      (SaveLoadV1_0.load(sc, path), features, classes)
    case (`classNameV2_0`, "2.0") =>
      val (features, classes) = ClassificationModel.getNumFeaturesClasses(metadata)
      (SaveLoadV2_0.load(sc, path), features, classes)
    case _ => throw new Exception(
      s"NaiveBayesModel.load did not recognize model with (className, format version): " +
      s"($loadedClassName, $version).  Supported:\n" +
      s"  ($classNameV1_0, 1.0)\n" +
      s"  ($classNameV2_0, 2.0)")
  }
  assert(model.pi.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
      s" but class priors vector pi had ${model.pi.size} elements")
  assert(model.theta.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
      s" but class conditionals array theta had ${model.theta.size} elements")
  assert(model.theta.forall(_.size == numFeatures),
    s"NaiveBayesModel.load expected $numFeatures features," +
      s" but class conditionals array theta had elements of size:" +
      s" ${model.theta.map(_.size).mkString(",")}")
  model
}
185
238
}
186
239
@@ -197,9 +250,9 @@ class NaiveBayes private (
197
250
private var lambda : Double ,
198
251
private var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
199
252
200
/** Creates a trainer with the given smoothing parameter and the Multinomial model type. */
def this(lambda: Double) = {
  this(lambda, Multinomial)
}
201
254
202
/** Creates a trainer with smoothing parameter 1.0 and the Multinomial model type. */
def this() = {
  // The one-argument constructor supplies the Multinomial model type.
  this(1.0)
}
203
256
204
257
/** Set the smoothing parameter. Default: 1.0. */
205
258
def setLambda (lambda : Double ): NaiveBayes = {
@@ -210,9 +263,22 @@ class NaiveBayes private (
210
263
/** Get the smoothing parameter. */
211
264
def getLambda : Double = lambda
212
265
213
/**
 * Set the model type using a string (case-insensitive).
 * Supported options: "multinomial" and "bernoulli".
 * (default: multinomial)
 */
def setModelType(modelType: String): NaiveBayes = {
  // Parse first; an unrecognized string makes fromString throw.
  val parsedType = NaiveBayes.ModelType.fromString(modelType)
  setModelType(parsedType)
}

/**
 * Set the model type.
 * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
 * (default: Multinomial)
 */
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
  this.modelType = modelType
  // Return this instance so setter calls can be chained.
  this
}
218
284
@@ -270,8 +336,11 @@ class NaiveBayes private (
270
336
labels(i) = label
271
337
pi(i) = math.log(n + lambda) - piLogDenom
272
338
val thetaLogDenom = modelType match {
273
- case NaiveBayes .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
274
- case NaiveBayes .Bernoulli => math.log(n + 2.0 * lambda)
339
+ case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340
+ case Bernoulli => math.log(n + 2.0 * lambda)
341
+ case _ =>
342
+ // This should never happen.
343
+ throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
275
344
}
276
345
var j = 0
277
346
while (j < numFeatures) {
@@ -317,7 +386,7 @@ object NaiveBayes {
317
386
* @param lambda The smoothing parameter
318
387
*/
319
388
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
  // The one-argument constructor selects the Multinomial model type.
  val trainer = new NaiveBayes(lambda)
  trainer.run(input)
}
322
391
323
392
/**
@@ -339,35 +408,42 @@ object NaiveBayes {
339
408
* multinomial or bernoulli
340
409
*/
341
410
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
  // An unrecognized string makes fromString throw before any training starts.
  val parsedType = ModelType.fromString(modelType)
  val trainer = new NaiveBayes(lambda, parsedType)
  trainer.run(input)
}
344
413
345
414
/**
 * Base type of the model types supported by NaiveBayes.
 * The supported values are exposed by the companion object.
 */
sealed abstract class ModelType extends Serializable

/** Provides static methods and constants for using ModelType. */
object ModelType extends Serializable {

  /**
   * Get the model type from a string.
   * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
   * @throws IllegalArgumentException if the string is not one of the supported names
   */
  def fromString(modelType: String): ModelType = {
    // Lower-case with Locale.ROOT: the default-locale variant maps 'I' to a dotless 'ı'
    // under e.g. a Turkish locale, which would make "MULTINOMIAL" unrecognizable.
    modelType.toLowerCase(java.util.Locale.ROOT) match {
      case "multinomial" => Multinomial
      case "bernoulli" => Bernoulli
      case _ =>
        throw new IllegalArgumentException(
          s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
    }
  }

  // The case objects are members of this object rather than locals of a val initializer:
  // member objects keep their singleton identity across Java serialization, so
  // `case Multinomial =>` style matches still succeed on a deserialized model type.
  private case object MultinomialType extends ModelType {
    override def toString: String = "multinomial"
  }

  private case object BernoulliType extends ModelType {
    override def toString: String = "bernoulli"
  }

  /** Constant for specifying the multinomial model type. */
  final val Multinomial: ModelType = MultinomialType

  /** Constant for specifying the bernoulli model type. */
  final val Bernoulli: ModelType = BernoulliType
}
366
445
367
- /** Constant for specifying ModelType parameter: bernoulli model */
368
- final val Bernoulli : ModelType = new ModelType {
369
- override def toString : String = ModelType .BERNOULLI_STRING
370
- }
446
/** Java-friendly accessor for supported ModelType options */
final val modelTypes: ModelType.type = ModelType
371
448
372
449
}
373
-
0 commit comments