Skip to content

Commit acb69af

Browse files
committed
removed enum type and replaces all modelType parameters with strings
1 parent 2224b15 commit acb69af

File tree

3 files changed

+47
-83
lines changed

3 files changed

+47
-83
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 33 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
3535
import org.apache.spark.rdd.RDD
3636
import org.apache.spark.sql.{DataFrame, SQLContext}
3737

38-
import NaiveBayes.ModelType.{Bernoulli, Multinomial}
39-
4038

4139
/**
4240
* Model for Naive Bayes Classifiers.
@@ -45,18 +43,17 @@ import NaiveBayes.ModelType.{Bernoulli, Multinomial}
4543
* @param pi log of class priors, whose dimension is C, number of labels
4644
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
4745
* where D is number of features
48-
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
49-
* Multinomial or Bernoulli
46+
* @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
5047
*/
5148
class NaiveBayesModel private[mllib] (
5249
val labels: Array[Double],
5350
val pi: Array[Double],
5451
val theta: Array[Array[Double]],
55-
val modelType: NaiveBayes.ModelType)
52+
val modelType: String)
5653
extends ClassificationModel with Serializable with Saveable {
5754

5855
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
59-
this(labels, pi, theta, Multinomial)
56+
this(labels, pi, theta, "Multinomial")
6057

6158
/** A Java-friendly constructor that takes three Iterable parameters. */
6259
private[mllib] def this(
@@ -72,8 +69,8 @@ class NaiveBayesModel private[mllib] (
7269
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
7370
// application of this condition (in predict function).
7471
private val (brzNegTheta, brzNegThetaSum) = modelType match {
75-
case Multinomial => (None, None)
76-
case Bernoulli =>
72+
case "Multinomial" => (None, None)
73+
case "Bernoulli" =>
7774
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
7875
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
7976
case _ =>
@@ -91,9 +88,9 @@ class NaiveBayesModel private[mllib] (
9188

9289
override def predict(testData: Vector): Double = {
9390
modelType match {
94-
case Multinomial =>
91+
case "Multinomial" =>
9592
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
96-
case Bernoulli =>
93+
case "Bernoulli" =>
9794
labels (brzArgmax (brzPi +
9895
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9996
case _ =>
@@ -103,7 +100,7 @@ class NaiveBayesModel private[mllib] (
103100
}
104101

105102
override def save(sc: SparkContext, path: String): Unit = {
106-
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
103+
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
107104
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
108105
}
109106

@@ -155,7 +152,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
155152
val labels = data.getAs[Seq[Double]](0).toArray
156153
val pi = data.getAs[Seq[Double]](1).toArray
157154
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
158-
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
155+
val modelType = data.getString(3)
159156
new NaiveBayesModel(labels, pi, theta, modelType)
160157
}
161158

@@ -248,11 +245,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
248245

249246
class NaiveBayes private (
250247
private var lambda: Double,
251-
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
248+
private var modelType: String) extends Serializable with Logging {
252249

253-
def this(lambda: Double) = this(lambda, Multinomial)
250+
def this(lambda: Double) = this(lambda, "Multinomial")
254251

255-
def this() = this(1.0, Multinomial)
252+
def this() = this(1.0, "Multinomial")
256253

257254
/** Set the smoothing parameter. Default: 1.0. */
258255
def setLambda(lambda: Double): NaiveBayes = {
@@ -264,26 +261,21 @@ class NaiveBayes private (
264261
def getLambda: Double = lambda
265262

266263
/**
267-
* Set the model type using a string (case-insensitive).
268-
* Supported options: "multinomial" and "bernoulli".
269-
* (default: multinomial)
270-
*/
271-
def setModelType(modelType: String): NaiveBayes = {
272-
setModelType(NaiveBayes.ModelType.fromString(modelType))
273-
}
274-
275-
/**
276-
* Set the model type.
277-
* Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
264+
* Set the model type using a string (case-sensitive).
265+
* Supported options: "Multinomial" and "Bernoulli".
278266
* (default: Multinomial)
279267
*/
280-
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
281-
this.modelType = modelType
282-
this
268+
def setModelType(modelType:String): NaiveBayes = {
269+
if (NaiveBayes.supportedModelTypes.contains(modelType)) {
270+
this.modelType = modelType
271+
this
272+
} else {
273+
throw new UnknownError(s"NaiveBayesModel does not support ModelType: $modelType")
274+
}
283275
}
284276

285277
/** Get the model type. */
286-
def getModelType: NaiveBayes.ModelType = this.modelType
278+
def getModelType: String = this.modelType
287279

288280
/**
289281
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -336,8 +328,8 @@ class NaiveBayes private (
336328
labels(i) = label
337329
pi(i) = math.log(n + lambda) - piLogDenom
338330
val thetaLogDenom = modelType match {
339-
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340-
case Bernoulli => math.log(n + 2.0 * lambda)
331+
case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
332+
case "Bernoulli" => math.log(n + 2.0 * lambda)
341333
case _ =>
342334
// This should never happen.
343335
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
@@ -358,6 +350,10 @@ class NaiveBayes private (
358350
* Top-level methods for calling naive Bayes.
359351
*/
360352
object NaiveBayes {
353+
354+
/* Set of modelTypes that NaiveBayes supports */
355+
private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
356+
361357
/**
362358
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
363359
*
@@ -386,7 +382,7 @@ object NaiveBayes {
386382
* @param lambda The smoothing parameter
387383
*/
388384
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
389-
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
385+
new NaiveBayes(lambda, "Multinomial").run(input)
390386
}
391387

392388
/**
@@ -408,42 +404,11 @@ object NaiveBayes {
408404
* multinomial or bernoulli
409405
*/
410406
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
411-
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
412-
}
413-
414-
/** Provides static methods for using ModelType. */
415-
sealed abstract class ModelType extends Serializable
416-
417-
object ModelType extends Serializable {
418-
419-
/**
420-
* Get the model type from a string.
421-
* @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
422-
*/
423-
def fromString(modelType: String): ModelType = modelType.toLowerCase match {
424-
case "multinomial" => Multinomial
425-
case "bernoulli" => Bernoulli
426-
case _ =>
427-
throw new IllegalArgumentException(
428-
s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
429-
}
430-
431-
final val Multinomial: ModelType = {
432-
case object Multinomial extends ModelType with Serializable {
433-
override def toString: String = "multinomial"
434-
}
435-
Multinomial
436-
}
437-
438-
final val Bernoulli: ModelType = {
439-
case object Bernoulli extends ModelType with Serializable {
440-
override def toString: String = "bernoulli"
441-
}
442-
Bernoulli
407+
if (supportedModelTypes.contains(modelType)) {
408+
new NaiveBayes(lambda, modelType).run(input)
409+
} else {
410+
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
443411
}
444412
}
445413

446-
/** Java-friendly accessor for supported ModelType options */
447-
final val modelTypes = ModelType
448-
449414
}

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception {
108108
@Test
109109
public void testModelTypeSetters() {
110110
NaiveBayes nb = new NaiveBayes()
111-
.setModelType(NaiveBayes.modelTypes().Bernoulli())
112-
.setModelType(NaiveBayes.modelTypes().Multinomial());
111+
.setModelType("Bernoulli")
112+
.setModelType("Multinomial");
113113
}
114114
}

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import breeze.stats.distributions.{Multinomial => BrzMultinomial}
2525
import org.scalatest.FunSuite
2626

2727
import org.apache.spark.SparkException
28-
import org.apache.spark.mllib.classification.NaiveBayes.ModelType.{Bernoulli, Multinomial}
2928
import org.apache.spark.mllib.linalg.Vectors
3029
import org.apache.spark.mllib.regression.LabeledPoint
3130
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -49,7 +48,7 @@ object NaiveBayesSuite {
4948
theta: Array[Array[Double]], // CXD
5049
nPoints: Int,
5150
seed: Int,
52-
modelType: NaiveBayes.ModelType = Multinomial,
51+
modelType: String = "Multinomial",
5352
sample: Int = 10): Seq[LabeledPoint] = {
5453
val D = theta(0).length
5554
val rnd = new Random(seed)
@@ -59,10 +58,10 @@ object NaiveBayesSuite {
5958
for (i <- 0 until nPoints) yield {
6059
val y = calcLabel(rnd.nextDouble(), _pi)
6160
val xi = modelType match {
62-
case Bernoulli => Array.tabulate[Double] (D) { j =>
61+
case "Bernoulli" => Array.tabulate[Double] (D) { j =>
6362
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
6463
}
65-
case Multinomial =>
64+
case "Multinomial" =>
6665
val mult = BrzMultinomial(BDV(_theta(y)))
6766
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
6867
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -81,12 +80,12 @@ object NaiveBayesSuite {
8180
/** Bernoulli NaiveBayes with binary labels, 3 features */
8281
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
8382
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
84-
Bernoulli)
83+
"Bernoulli")
8584

8685
/** Multinomial NaiveBayes with binary labels, 3 features */
8786
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
8887
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
89-
Multinomial)
88+
"Multinomial")
9089
}
9190

9291
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
@@ -136,15 +135,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
136135
).map(_.map(math.log))
137136

138137
val testData = NaiveBayesSuite.generateNaiveBayesInput(
139-
pi, theta, nPoints, 42, Multinomial)
138+
pi, theta, nPoints, 42, "Multinomial")
140139
val testRDD = sc.parallelize(testData, 2)
141140
testRDD.cache()
142141

143-
val model = NaiveBayes.train(testRDD, 1.0, "multinomial")
142+
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
144143
validateModelFit(pi, theta, model)
145144

146145
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
147-
pi, theta, nPoints, 17, Multinomial)
146+
pi, theta, nPoints, 17, "Multinomial")
148147
val validationRDD = sc.parallelize(validationData, 2)
149148

150149
// Test prediction on RDD.
@@ -164,15 +163,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
164163
).map(_.map(math.log))
165164

166165
val testData = NaiveBayesSuite.generateNaiveBayesInput(
167-
pi, theta, nPoints, 45, Bernoulli)
166+
pi, theta, nPoints, 45, "Bernoulli")
168167
val testRDD = sc.parallelize(testData, 2)
169168
testRDD.cache()
170169

171-
val model = NaiveBayes.train(testRDD, 1.0, "bernoulli")
170+
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
172171
validateModelFit(pi, theta, model)
173172

174173
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
175-
pi, theta, nPoints, 20, Bernoulli)
174+
pi, theta, nPoints, 20, "Bernoulli")
176175
val validationRDD = sc.parallelize(validationData, 2)
177176

178177
// Test prediction on RDD.
@@ -243,7 +242,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
243242
assert(model.labels === sameModel.labels)
244243
assert(model.pi === sameModel.pi)
245244
assert(model.theta === sameModel.theta)
246-
assert(model.modelType === NaiveBayes.ModelType.Multinomial)
245+
assert(model.modelType === "Multinomial")
247246
} finally {
248247
Utils.deleteRecursively(tempDir)
249248
}

0 commit comments

Comments
 (0)