Skip to content

Commit 2224b15

Browse files
committed
Merge pull request #2 from jkbradley/leahmcguire-master
Added model save/load version to support NaiveBayes ModelType
2 parents 852a727 + 9ad89ca commit 2224b15

File tree

3 files changed

+198
-81
lines changed

3 files changed

+198
-81
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 131 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
3535
import org.apache.spark.rdd.RDD
3636
import org.apache.spark.sql.{DataFrame, SQLContext}
3737

38+
import NaiveBayes.ModelType.{Bernoulli, Multinomial}
39+
3840

3941
/**
4042
* Model for Naive Bayes Classifiers.
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
5456
extends ClassificationModel with Serializable with Saveable {
5557

5658
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
57-
this(labels, pi, theta, NaiveBayes.Multinomial)
59+
this(labels, pi, theta, Multinomial)
5860

5961
/** A Java-friendly constructor that takes three Iterable parameters. */
6062
private[mllib] def this(
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
7072
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
7173
// application of this condition (in predict function).
7274
private val (brzNegTheta, brzNegThetaSum) = modelType match {
73-
case NaiveBayes.Multinomial => (None, None)
74-
case NaiveBayes.Bernoulli =>
75+
case Multinomial => (None, None)
76+
case Bernoulli =>
7577
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
7678
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
79+
case _ =>
80+
// This should never happen.
81+
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
7782
}
7883

7984
override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (
8691

8792
override def predict(testData: Vector): Double = {
8893
modelType match {
89-
case NaiveBayes.Multinomial =>
94+
case Multinomial =>
9095
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
91-
case NaiveBayes.Bernoulli =>
96+
case Bernoulli =>
9297
labels (brzArgmax (brzPi +
9398
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
99+
case _ =>
100+
// This should never happen.
101+
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
94102
}
95103
}
96104

97105
override def save(sc: SparkContext, path: String): Unit = {
98-
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
99-
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
106+
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
107+
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
100108
}
101109

102-
override protected def formatVersion: String = "1.0"
110+
override protected def formatVersion: String = "2.0"
103111
}
104112

105113
object NaiveBayesModel extends Loader[NaiveBayesModel] {
106114

107115
import org.apache.spark.mllib.util.Loader._
108116

109-
private object SaveLoadV1_0 {
117+
private[mllib] object SaveLoadV2_0 {
110118

111-
def thisFormatVersion: String = "1.0"
119+
def thisFormatVersion: String = "2.0"
112120

113121
/** Hard-code class name string in case it changes in the future */
114122
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
127135
// Create JSON metadata.
128136
val metadata = compact(render(
129137
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
130-
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~
131-
("modelType" -> data.modelType)))
138+
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
132139
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
133140

134141
// Create Parquet data.
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
151158
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
152159
new NaiveBayesModel(labels, pi, theta, modelType)
153160
}
161+
154162
}
155163

156-
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
157-
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
158-
implicit val formats = DefaultFormats
159-
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
164+
/**
 * Reader/writer for the "1.0" on-disk format of [[NaiveBayesModel]].
 *
 * This format stores only `labels`, `pi` and `theta`; it carries no model type field.
 */
private[mllib] object SaveLoadV1_0 {

  /** Format version string written into, and expected from, the metadata. */
  def thisFormatVersion: String = "1.0"

  /** Hard-code class name string in case it changes in the future */
  def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

  /** Model data for model import/export */
  case class Data(
      labels: Array[Double],
      pi: Array[Double],
      theta: Array[Array[Double]])

  /**
   * Write the JSON metadata and the Parquet model data under `path`.
   *
   * @param sc    Spark context used to create the output RDDs
   * @param path  base location of the saved model
   * @param data  model parameters to persist
   */
  def save(sc: SparkContext, path: String, data: Data): Unit = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Metadata: class name, format version and the model dimensions.
    val featureCount = data.theta(0).length
    val classCount = data.pi.length
    val metadataJson =
      ("class" -> thisClassName) ~
        ("version" -> thisFormatVersion) ~
        ("numFeatures" -> featureCount) ~
        ("numClasses" -> classCount)
    val metadataText = compact(render(metadataJson))
    sc.parallelize(Seq(metadataText), 1).saveAsTextFile(metadataPath(path))

    // Model parameters: a single-row DataFrame written as Parquet.
    val modelFrame: DataFrame = sc.parallelize(Seq(data), 1).toDF()
    modelFrame.saveAsParquetFile(dataPath(path))
  }

  /**
   * Read the Parquet model data under `path` and rebuild the model from it.
   *
   * @param sc    Spark context used to read the data
   * @param path  base location of the saved model
   * @return a model built with the three-argument constructor (labels, pi, theta)
   */
  def load(sc: SparkContext, path: String): NaiveBayesModel = {
    val sqlContext = new SQLContext(sc)
    val modelFrame = sqlContext.parquetFile(dataPath(path))
    // Check schema explicitly since erasure makes it hard to use match-case for checking.
    checkSchema[Data](modelFrame.schema)

    val rows = modelFrame.select("labels", "pi", "theta").take(1)
    assert(rows.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")

    val row = rows(0)
    val labels = row.getAs[Seq[Double]](0).toArray
    val pi = row.getAs[Seq[Double]](1).toArray
    val theta = row.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
    new NaiveBayesModel(labels, pi, theta)
  }
}
207+
208+
/**
 * Load a NaiveBayesModel from the given path.
 *
 * The metadata read from `path` selects the reader: format version "1.0" is handled by
 * SaveLoadV1_0 and format version "2.0" by SaveLoadV2_0. After loading, the sizes of `pi`
 * and `theta` are checked (via `assert`) against the numClasses / numFeatures values
 * recorded in the metadata.
 *
 * @param sc    Spark context used to read the metadata and the model data
 * @param path  base location of the saved model
 * @return the loaded model
 * @throws Exception if the (className, format version) pair in the metadata is not one of
 *                   the supported combinations
 */
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
  val (loadedClassName, version, metadata) = loadMetadata(sc, path)
  val classNameV1_0 = SaveLoadV1_0.thisClassName
  val classNameV2_0 = SaveLoadV2_0.thisClassName
  val (model, numFeatures, numClasses) = (loadedClassName, version) match {
    case (className, "1.0") if className == classNameV1_0 =>
      val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
      val model = SaveLoadV1_0.load(sc, path)
      (model, numFeatures, numClasses)
    case (className, "2.0") if className == classNameV2_0 =>
      val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
      val model = SaveLoadV2_0.load(sc, path)
      (model, numFeatures, numClasses)
    case _ => throw new Exception(
      // List every format this loader accepts, so the message matches the cases above.
      s"NaiveBayesModel.load did not recognize model with (className, format version):" +
      s"($loadedClassName, $version). Supported:\n" +
      s" ($classNameV1_0, 1.0)\n" +
      s" ($classNameV2_0, 2.0)")
  }
  // Sanity-check the loaded parameters against the dimensions stored in the metadata.
  assert(model.pi.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
    s" but class priors vector pi had ${model.pi.size} elements")
  assert(model.theta.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
    s" but class conditionals array theta had ${model.theta.size} elements")
  assert(model.theta.forall(_.size == numFeatures),
    s"NaiveBayesModel.load expected $numFeatures features," +
    s" but class conditionals array theta had elements of size:" +
    s" ${model.theta.map(_.size).mkString(",")}")
  model
}
185238
}
186239

@@ -197,9 +250,9 @@ class NaiveBayes private (
197250
private var lambda: Double,
198251
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
199252

200-
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
253+
def this(lambda: Double) = this(lambda, Multinomial)
201254

202-
def this() = this(1.0, NaiveBayes.Multinomial)
255+
def this() = this(1.0, Multinomial)
203256

204257
/** Set the smoothing parameter. Default: 1.0. */
205258
def setLambda(lambda: Double): NaiveBayes = {
@@ -210,9 +263,22 @@ class NaiveBayes private (
210263
/** Get the smoothing parameter. */
211264
def getLambda: Double = lambda
212265

213-
/**
 * Set the model type using a string (case-insensitive).
 * Supported options: "multinomial" and "bernoulli".
 * (default: multinomial)
 *
 * @param modelType name of the model type, parsed with NaiveBayes.ModelType.fromString
 * @return this instance, so that calls can be chained
 */
def setModelType(modelType: String): NaiveBayes = {
  val parsedType = NaiveBayes.ModelType.fromString(modelType)
  setModelType(parsedType)
}

/**
 * Set the model type.
 * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
 * (default: Multinomial)
 *
 * @param modelType the model type to use
 * @return this instance, so that calls can be chained
 */
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
  this.modelType = modelType
  this
}
218284

@@ -270,8 +336,11 @@ class NaiveBayes private (
270336
labels(i) = label
271337
pi(i) = math.log(n + lambda) - piLogDenom
272338
val thetaLogDenom = modelType match {
273-
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
274-
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
339+
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340+
case Bernoulli => math.log(n + 2.0 * lambda)
341+
case _ =>
342+
// This should never happen.
343+
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
275344
}
276345
var j = 0
277346
while (j < numFeatures) {
@@ -317,7 +386,7 @@ object NaiveBayes {
317386
* @param lambda The smoothing parameter
318387
*/
319388
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
320-
new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input)
389+
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
321390
}
322391

323392
/**
@@ -339,35 +408,42 @@ object NaiveBayes {
339408
* multinomial or bernoulli
340409
*/
341410
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
342-
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
411+
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
343412
}
344413

345414
/** Base type of the model types supported by NaiveBayes. */
sealed abstract class ModelType extends Serializable

/** Provides static methods for using ModelType. */
object ModelType extends Serializable {

  /**
   * Get the model type from a string.
   * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
   * @return the matching model type constant
   * @throws IllegalArgumentException if the string does not name a supported model type
   */
  def fromString(modelType: String): ModelType = {
    // Lower-case with Locale.ROOT so the result does not depend on the JVM default locale
    // (e.g. under a Turkish locale "I" would become a dotless "ı" and never match).
    modelType.toLowerCase(java.util.Locale.ROOT) match {
      case "multinomial" => Multinomial
      case "bernoulli" => Bernoulli
      case _ =>
        throw new IllegalArgumentException(
          s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
    }
  }

  /** Constant for specifying ModelType parameter: multinomial model */
  final val Multinomial: ModelType = {
    case object Multinomial extends ModelType with Serializable {
      override def toString: String = "multinomial"
    }
    Multinomial
  }

  /** Constant for specifying ModelType parameter: bernoulli model */
  final val Bernoulli: ModelType = {
    case object Bernoulli extends ModelType with Serializable {
      override def toString: String = "bernoulli"
    }
    Bernoulli
  }
}
366445

367-
/** Constant for specifying ModelType parameter: bernoulli model */
368-
final val Bernoulli: ModelType = new ModelType {
369-
override def toString: String = ModelType.BERNOULLI_STRING
370-
}
446+
/** Java-friendly accessor for supported ModelType options */
447+
final val modelTypes = ModelType
371448

372449
}
373-

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,22 @@
1717

1818
package org.apache.spark.mllib.classification;
1919

20+
import java.io.Serializable;
21+
import java.util.Arrays;
22+
import java.util.List;
23+
24+
import org.junit.After;
25+
import org.junit.Assert;
26+
import org.junit.Before;
27+
import org.junit.Test;
28+
2029
import org.apache.spark.api.java.JavaRDD;
2130
import org.apache.spark.api.java.JavaSparkContext;
2231
import org.apache.spark.api.java.function.Function;
2332
import org.apache.spark.mllib.linalg.Vector;
2433
import org.apache.spark.mllib.linalg.Vectors;
2534
import org.apache.spark.mllib.regression.LabeledPoint;
26-
import org.junit.After;
27-
import org.junit.Assert;
28-
import org.junit.Before;
29-
import org.junit.Test;
3035

31-
import java.io.Serializable;
32-
import java.util.Arrays;
33-
import java.util.List;
3436

3537
public class JavaNaiveBayesSuite implements Serializable {
3638
private transient JavaSparkContext sc;
@@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception {
102104
// Should be able to get the first prediction.
103105
predictions.first();
104106
}
107+
108+
@Test
109+
public void testModelTypeSetters() {
110+
NaiveBayes nb = new NaiveBayes()
111+
.setModelType(NaiveBayes.modelTypes().Bernoulli())
112+
.setModelType(NaiveBayes.modelTypes().Multinomial());
113+
}
105114
}

0 commit comments

Comments
 (0)