@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
19
19
20
20
import breeze .linalg .{DenseMatrix => BDM , DenseVector => BDV , argmax => brzArgmax , sum => brzSum }
21
21
22
- import org .apache .spark .Logging
22
+ import org .apache .spark .{ SparkException , Logging }
23
23
import org .apache .spark .SparkContext ._
24
- import org .apache .spark .mllib .linalg .Vector
24
+ import org .apache .spark .mllib .linalg .{ DenseVector , SparseVector , Vector }
25
25
import org .apache .spark .mllib .regression .LabeledPoint
26
26
import org .apache .spark .rdd .RDD
27
27
@@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] (
73
73
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p ]]) which can handle all kinds of
74
74
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
75
75
* document classification. By making every vector a 0-1 vector, it can also be used as
76
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]).
76
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The input feature values must be nonnegative.
77
77
*/
78
78
class NaiveBayes private (private var lambda : Double ) extends Serializable with Logging {
79
79
@@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
91
91
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
92
92
*/
93
93
def run (data : RDD [LabeledPoint ]) = {
94
+ val requireNonnegativeValues : Vector => Unit = (v : Vector ) => {
95
+ val values = v match {
96
+ case sv : SparseVector =>
97
+ sv.values
98
+ case dv : DenseVector =>
99
+ dv.values
100
+ }
101
+ if (! values.forall(_ >= 0.0 )) {
102
+ throw new SparkException (s " Naive Bayes requires nonnegative feature values but found $v. " )
103
+ }
104
+ }
105
+
94
106
// Aggregates term frequencies per label.
95
107
// TODO: Calling combineByKey and collect creates two stages; we can implement something
96
108
// TODO: similar to reduceByKeyLocally to save one stage.
97
109
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long , BDV [Double ])](
98
- createCombiner = (v : Vector ) => (1L , v.toBreeze.toDenseVector),
99
- mergeValue = (c : (Long , BDV [Double ]), v : Vector ) => (c._1 + 1L , c._2 += v.toBreeze),
110
+ createCombiner = (v : Vector ) => {
111
+ requireNonnegativeValues(v)
112
+ (1L , v.toBreeze.toDenseVector)
113
+ },
114
+ mergeValue = (c : (Long , BDV [Double ]), v : Vector ) => {
115
+ requireNonnegativeValues(v)
116
+ (c._1 + 1L , c._2 += v.toBreeze)
117
+ },
100
118
mergeCombiners = (c1 : (Long , BDV [Double ]), c2 : (Long , BDV [Double ])) =>
101
119
(c1._1 + c2._1, c1._2 += c2._2)
102
120
).collect()
0 commit comments