Skip to content

Commit 52c37c3

Browse files
committed
address comments
1 parent 65f892d commit 52c37c3

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
9898
case dv: DenseVector =>
9999
dv.values
100100
}
101-
if (!values.forall(x => x >= 0.0)) {
101+
if (!values.forall(_ >= 0.0)) {
102102
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
103103
}
104104
}

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
114114
intercept[SparkException] {
115115
NaiveBayes.train(sc.makeRDD(sparse, 2))
116116
}
117+
val nan = Seq(
118+
LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
119+
LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))),
120+
LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
121+
LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
122+
intercept[SparkException] {
123+
NaiveBayes.train(sc.makeRDD(nan, 2))
124+
}
117125
}
118126
}
119127

0 commit comments

Comments
 (0)