Commit 068b6fe

[SPARK-3130][MLLIB] detect negative values in naive Bayes

Detect and reject negative feature values, because NB treats feature values as term frequencies.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #2038 from mengxr/nb-neg and squashes the following commits:

52c37c3 [Xiangrui Meng] address comments
65f892d [Xiangrui Meng] detect negative values in nb

1 parent 0e3ab94 commit 068b6fe
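
The user-facing effect, as a minimal sketch (assuming a live `SparkContext` named `sc`, which is not part of this commit): training on data containing a negative feature value now fails fast with a `SparkException` instead of silently fitting a meaningless model.

```scala
import org.apache.spark.SparkException
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0)),
  LabeledPoint(0.0, Vectors.dense(-1.0)))) // negative "term frequency"

try {
  NaiveBayes.train(data) // throws once the invalid vector is seen
} catch {
  case e: SparkException => println(s"rejected: ${e.getMessage}")
}
```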

File tree

3 files changed: +53 -6 lines changed

docs/mllib-naive-bayes.md

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@ Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bay
 which is typically used for [document
 classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
 Within that context, each observation is a document and each
-feature represents a term whose value is the frequency of the term.
+feature represents a term whose value is the frequency of the term.
+Feature values must be nonnegative to represent term frequencies.
 [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
 setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
 vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
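
For context, the additive-smoothing parameter mentioned in the doc is exposed through `NaiveBayes.train`. A minimal usage sketch, assuming a live `SparkContext` named `sc` (not part of this diff):

```scala
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Term-frequency vectors: all feature values are nonnegative counts.
// Sparse vectors are used since document TF vectors are usually sparse.
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.sparse(3, Array(0, 2), Array(2.0, 1.0))),
  LabeledPoint(1.0, Vectors.sparse(3, Array(1), Array(3.0)))))

// Additive (Lidstone) smoothing; lambda defaults to 1.0 (Laplace smoothing).
val model = NaiveBayes.train(training, lambda = 0.5)

// Predict the label of a new term-frequency vector.
val label = model.predict(Vectors.sparse(3, Array(1), Array(1.0)))
```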

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 23 additions & 5 deletions
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification

 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}

-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD

@@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] (
  * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
  * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
  * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
  */
 class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {

@@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
    * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    */
   def run(data: RDD[LabeledPoint]) = {
+    val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
+      val values = v match {
+        case sv: SparseVector =>
+          sv.values
+        case dv: DenseVector =>
+          dv.values
+      }
+      if (!values.forall(_ >= 0.0)) {
+        throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
+      }
+    }
+
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
-      createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
-      mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+      createCombiner = (v: Vector) => {
+        requireNonnegativeValues(v)
+        (1L, v.toBreeze.toDenseVector)
+      },
+      mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+        requireNonnegativeValues(v)
+        (c._1 + 1L, c._2 += v.toBreeze)
+      },
       mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
         (c1._1 + c2._1, c1._2 += c2._2)
     ).collect()
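
A side effect of the `forall(_ >= 0.0)` predicate introduced above is that `Double.NaN` is rejected as well, because every comparison with NaN evaluates to false; the new test below relies on this. A standalone sketch of the predicate's behavior, runnable in a Scala REPL:

```scala
// The validation predicate from the patch, applied to plain Seqs:
Seq(1.0, 0.0).forall(_ >= 0.0)        // true: all values nonnegative
Seq(1.0, -1.0).forall(_ >= 0.0)       // false: negative value detected
Seq(1.0, Double.NaN).forall(_ >= 0.0) // false: NaN >= 0.0 is false, so NaN is caught too
```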

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 28 additions & 0 deletions
@@ -21,6 +21,7 @@ import scala.util.Random

 import org.scalatest.FunSuite

+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}

@@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("detect negative values") {
+    val dense = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(-1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(dense, 2))
+    }
+    val sparse = Seq(
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(sparse, 2))
+    }
+    val nan = Seq(
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(nan, 2))
+    }
+  }
 }

 class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
