33
33
//! ## References:
34
34
//!
35
35
//! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html)
36
+ use std:: ops:: Not ;
37
+
36
38
use crate :: api:: { Predictor , SupervisedEstimator } ;
37
39
use crate :: error:: Failed ;
38
40
use crate :: linalg:: row_iter;
@@ -47,12 +49,26 @@ use serde::{Deserialize, Serialize};
47
49
48
50
/// Naive Bayes classifier for Bernoulli features
49
51
// Internal state of the Bernoulli Naive Bayes distribution: the class labels,
// the class priors and the per-class feature log-probabilities.
#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
50
- #[ derive( Debug , PartialEq ) ]
52
+ #[ derive( Debug ) ]
51
53
struct BernoulliNBDistribution < T : RealNumber > {
52
54
/// class labels known to the classifier
53
55
class_labels : Vec < T > ,
54
56
/// per-class prior values (presumably one entry per class label, in the same order; confirm against the fitting code)
class_priors : Vec < T > ,
55
- feature_prob : Vec < Vec < T > > ,
57
/// natural logarithm of the smoothed feature probabilities, indexed as `[class_index][feature]`
+ feature_log_prob : Vec < Vec < T > > ,
58
+ }
59
+
60
+ impl < T : RealNumber > PartialEq for BernoulliNBDistribution < T > {
61
+ fn eq ( & self , other : & Self ) -> bool {
62
+ if self . class_labels == other. class_labels && self . class_priors == other. class_priors {
63
+ self . feature_log_prob
64
+ . iter ( )
65
+ . zip ( other. feature_log_prob . iter ( ) )
66
+ . any ( |( left, right) | !left. approximate_eq ( right, T :: epsilon ( ) ) )
67
+ . not ( )
68
+ } else {
69
+ false
70
+ }
71
+ }
56
72
}
57
73
58
74
impl < T : RealNumber , M : Matrix < T > > NBDistribution < T , M > for BernoulliNBDistribution < T > {
@@ -65,9 +81,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
65
81
for feature in 0 ..j. len ( ) {
66
82
let value = j. get ( feature) ;
67
83
if value == T :: one ( ) {
68
- likelihood += self . feature_prob [ class_index] [ feature] . ln ( ) ;
84
+ likelihood += self . feature_log_prob [ class_index] [ feature] ;
69
85
} else {
70
- likelihood += ( T :: one ( ) - self . feature_prob [ class_index] [ feature] ) . ln ( ) ;
86
+ likelihood += ( T :: one ( ) - self . feature_log_prob [ class_index] [ feature] . exp ( ) ) . ln ( ) ;
71
87
}
72
88
}
73
89
likelihood
@@ -185,21 +201,23 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
185
201
}
186
202
}
187
203
188
- let feature_prob = feature_in_class_counter
204
+ let feature_log_prob = feature_in_class_counter
189
205
. iter ( )
190
206
. enumerate ( )
191
207
. map ( |( class_index, feature_count) | {
192
208
feature_count
193
209
. iter ( )
194
- . map ( |& count| ( count + alpha) / ( class_count[ class_index] + alpha * T :: two ( ) ) )
210
+ . map ( |& count| {
211
+ ( ( count + alpha) / ( class_count[ class_index] + alpha * T :: two ( ) ) ) . ln ( )
212
+ } )
195
213
. collect ( )
196
214
} )
197
215
. collect ( ) ;
198
216
199
217
Ok ( Self {
200
218
class_labels,
201
219
class_priors,
202
- feature_prob ,
220
+ feature_log_prob ,
203
221
} )
204
222
}
205
223
}
@@ -272,6 +290,12 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
272
290
pub fn classes(&self) -> &Vec<T> {
    // The labels are stored on the distribution owned by `self.inner`.
    let distribution = &self.inner.distribution;
    &distribution.class_labels
}
293
+
294
+ /// Empirical log probability of features given a class, P(x_i|y).
295
+ /// Returns a 2d vector of shape (n_classes, n_features)
296
+ pub fn feature_log_prob ( & self ) -> & Vec < Vec < T > > {
297
+ & self . inner . distribution . feature_log_prob
298
+ }
275
299
}
276
300
277
301
#[ cfg( test) ]
@@ -302,10 +326,24 @@ mod tests {
302
326
303
327
assert_eq ! ( bnb. inner. distribution. class_priors, & [ 0.75 , 0.25 ] ) ;
304
328
assert_eq ! (
305
- bnb. inner . distribution . feature_prob ,
329
+ bnb. feature_log_prob ( ) ,
306
330
& [
307
- & [ 0.4 , 0.8 , 0.2 , 0.4 , 0.4 , 0.2 ] ,
308
- & [ 1. / 3.0 , 2. / 3.0 , 2. / 3.0 , 1. / 3.0 , 1. / 3.0 , 2. / 3.0 ]
331
+ & [
332
+ -0.916290731874155 ,
333
+ -0.2231435513142097 ,
334
+ -1.6094379124341003 ,
335
+ -0.916290731874155 ,
336
+ -0.916290731874155 ,
337
+ -1.6094379124341003
338
+ ] ,
339
+ & [
340
+ -1.0986122886681098 ,
341
+ -0.40546510810816444 ,
342
+ -0.40546510810816444 ,
343
+ -1.0986122886681098 ,
344
+ -1.0986122886681098 ,
345
+ -0.40546510810816444
346
+ ]
309
347
]
310
348
) ;
311
349
@@ -348,10 +386,22 @@ mod tests {
348
386
. distribution
349
387
. class_priors
350
388
. approximate_eq( & vec!( 0.46 , 0.2 , 0.33 ) , 1e-2 ) ) ;
351
- assert ! ( bnb. inner. distribution. feature_prob[ 1 ] . approximate_eq(
352
- & vec!( 0.8 , 0.8 , 0.8 , 0.4 , 0.8 , 0.6 , 0.8 , 0.6 , 0.6 , 0.8 ) ,
389
+ assert ! ( bnb. feature_log_prob( ) [ 1 ] . approximate_eq(
390
+ & vec![
391
+ -0.22314355 ,
392
+ -0.22314355 ,
393
+ -0.22314355 ,
394
+ -0.91629073 ,
395
+ -0.22314355 ,
396
+ -0.51082562 ,
397
+ -0.22314355 ,
398
+ -0.51082562 ,
399
+ -0.51082562 ,
400
+ -0.22314355
401
+ ] ,
353
402
1e-1
354
403
) ) ;
404
+ println ! ( "{:?}" , y_hat) ;
355
405
assert ! ( y_hat. approximate_eq(
356
406
& vec!( 2.0 , 2.0 , 0.0 , 0.0 , 0.0 , 2.0 , 1.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ) ,
357
407
1e-5
0 commit comments