Skip to content

Commit f3c8994

Browse files
committed
changed checks on model type to requires
1 parent acb69af commit f3c8994

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,10 @@ class NaiveBayes private (
266266
* (default: Multinomial)
267267
*/
268268
def setModelType(modelType:String): NaiveBayes = {
269-
if (NaiveBayes.supportedModelTypes.contains(modelType)) {
270-
this.modelType = modelType
271-
this
272-
} else {
273-
throw new UnknownError(s"NaiveBayesModel does not support ModelType: $modelType")
274-
}
269+
require(NaiveBayes.supportedModelTypes.contains(modelType),
270+
s"NaiveBayes was created with an unknown ModelType: $modelType")
271+
this.modelType = modelType
272+
this
275273
}
276274

277275
/** Get the model type. */
@@ -404,11 +402,9 @@ object NaiveBayes {
404402
* multinomial or bernoulli
405403
*/
406404
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
407-
if (supportedModelTypes.contains(modelType)) {
408-
new NaiveBayes(lambda, modelType).run(input)
409-
} else {
410-
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
411-
}
405+
require(supportedModelTypes.contains(modelType),
406+
s"NaiveBayes was created with an unknown ModelType: $modelType")
407+
new NaiveBayes(lambda, modelType).run(input)
412408
}
413409

414410
}

0 commit comments

Comments
 (0)