@@ -47,12 +47,44 @@ use serde::{Deserialize, Serialize};
47
47
48
48
/// Naive Bayes classifier for Bernoulli features.
/// Holds the per-class statistics estimated during `fit` and used by
/// the `NBDistribution` implementation to score samples.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// PartialEq is implemented manually (below in this file) because
// `feature_log_prob` holds floats and must be compared approximately.
#[derive(Debug)]
struct BernoulliNBDistribution<T: RealNumber> {
    /// class labels known to the classifier
    class_labels: Vec<T>,
    /// number of training samples observed in each class
    class_count: Vec<usize>,
    /// probability of each class (either user-supplied priors or
    /// empirical class frequencies from the training data)
    class_priors: Vec<T>,
    /// Number of samples encountered for each (class, feature),
    /// i.e. how many times the feature was 1 within the class
    feature_count: Vec<Vec<usize>>,
    /// probability of features per class, stored as natural logs of the
    /// Laplace-smoothed estimates: ln((count + alpha) / (class_count + 2*alpha))
    feature_log_prob: Vec<Vec<T>>,
    /// Number of features of each sample
    n_features: usize,
}
65
+
66
+ impl < T : RealNumber > PartialEq for BernoulliNBDistribution < T > {
67
+ fn eq ( & self , other : & Self ) -> bool {
68
+ if self . class_labels == other. class_labels
69
+ && self . class_count == other. class_count
70
+ && self . class_priors == other. class_priors
71
+ && self . feature_count == other. feature_count
72
+ && self . n_features == other. n_features
73
+ {
74
+ for ( a, b) in self
75
+ . feature_log_prob
76
+ . iter ( )
77
+ . zip ( other. feature_log_prob . iter ( ) )
78
+ {
79
+ if !a. approximate_eq ( b, T :: epsilon ( ) ) {
80
+ return false ;
81
+ }
82
+ }
83
+ true
84
+ } else {
85
+ false
86
+ }
87
+ }
56
88
}
57
89
58
90
impl < T : RealNumber , M : Matrix < T > > NBDistribution < T , M > for BernoulliNBDistribution < T > {
@@ -65,9 +97,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
65
97
for feature in 0 ..j. len ( ) {
66
98
let value = j. get ( feature) ;
67
99
if value == T :: one ( ) {
68
- likelihood += self . feature_prob [ class_index] [ feature] . ln ( ) ;
100
+ likelihood += self . feature_log_prob [ class_index] [ feature] ;
69
101
} else {
70
- likelihood += ( T :: one ( ) - self . feature_prob [ class_index] [ feature] ) . ln ( ) ;
102
+ likelihood += ( T :: one ( ) - self . feature_log_prob [ class_index] [ feature] . exp ( ) ) . ln ( ) ;
71
103
}
72
104
}
73
105
likelihood
@@ -157,10 +189,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
157
189
let y = y. to_vec ( ) ;
158
190
159
191
let ( class_labels, indices) = <Vec < T > as RealNumberVector < T > >:: unique_with_indices ( & y) ;
160
- let mut class_count = vec ! [ T :: zero ( ) ; class_labels. len( ) ] ;
192
+ let mut class_count = vec ! [ 0_usize ; class_labels. len( ) ] ;
161
193
162
194
for class_index in indices. iter ( ) {
163
- class_count[ * class_index] += T :: one ( ) ;
195
+ class_count[ * class_index] += 1 ;
164
196
}
165
197
166
198
let class_priors = if let Some ( class_priors) = priors {
@@ -173,33 +205,46 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
173
205
} else {
174
206
class_count
175
207
. iter ( )
176
- . map ( |& c| c / T :: from ( n_samples) . unwrap ( ) )
208
+ . map ( |& c| T :: from ( c ) . unwrap ( ) / T :: from ( n_samples) . unwrap ( ) )
177
209
. collect ( )
178
210
} ;
179
211
180
- let mut feature_in_class_counter = vec ! [ vec![ T :: zero ( ) ; n_features] ; class_labels. len( ) ] ;
212
+ let mut feature_in_class_counter = vec ! [ vec![ 0_usize ; n_features] ; class_labels. len( ) ] ;
181
213
182
214
for ( row, class_index) in row_iter ( x) . zip ( indices) {
183
215
for ( idx, row_i) in row. iter ( ) . enumerate ( ) . take ( n_features) {
184
- feature_in_class_counter[ class_index] [ idx] += * row_i;
216
+ feature_in_class_counter[ class_index] [ idx] +=
217
+ row_i. to_usize ( ) . ok_or_else ( || {
218
+ Failed :: fit ( & format ! (
219
+ "Elements of the matrix should be 1.0 or 0.0 |found|=[{}]" ,
220
+ row_i
221
+ ) )
222
+ } ) ?;
185
223
}
186
224
}
187
225
188
- let feature_prob = feature_in_class_counter
226
+ let feature_log_prob = feature_in_class_counter
189
227
. iter ( )
190
228
. enumerate ( )
191
229
. map ( |( class_index, feature_count) | {
192
230
feature_count
193
231
. iter ( )
194
- . map ( |& count| ( count + alpha) / ( class_count[ class_index] + alpha * T :: two ( ) ) )
232
+ . map ( |& count| {
233
+ ( ( T :: from ( count) . unwrap ( ) + alpha)
234
+ / ( T :: from ( class_count[ class_index] ) . unwrap ( ) + alpha * T :: two ( ) ) )
235
+ . ln ( )
236
+ } )
195
237
. collect ( )
196
238
} )
197
239
. collect ( ) ;
198
240
199
241
Ok ( Self {
200
242
class_labels,
201
243
class_priors,
202
- feature_prob,
244
+ class_count,
245
+ feature_count : feature_in_class_counter,
246
+ feature_log_prob,
247
+ n_features,
203
248
} )
204
249
}
205
250
}
@@ -266,6 +311,34 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
266
311
self . inner . predict ( x)
267
312
}
268
313
}
314
+
315
+ /// Class labels known to the classifier.
316
+ /// Returns a vector of size n_classes.
317
+ pub fn classes ( & self ) -> & Vec < T > {
318
+ & self . inner . distribution . class_labels
319
+ }
320
+
321
+ /// Number of training samples observed in each class.
322
+ /// Returns a vector of size n_classes.
323
+ pub fn class_count ( & self ) -> & Vec < usize > {
324
+ & self . inner . distribution . class_count
325
+ }
326
+
327
+ /// Number of features of each sample
328
+ pub fn n_features ( & self ) -> usize {
329
+ self . inner . distribution . n_features
330
+ }
331
+
332
+ /// Number of samples encountered for each (class, feature)
333
+ /// Returns a 2d vector of shape (n_classes, n_features)
334
+ pub fn feature_count ( & self ) -> & Vec < Vec < usize > > {
335
+ & self . inner . distribution . feature_count
336
+ }
337
+
338
+ /// Empirical log probability of features given a class
339
+ pub fn feature_log_prob ( & self ) -> & Vec < Vec < T > > {
340
+ & self . inner . distribution . feature_log_prob
341
+ }
269
342
}
270
343
271
344
#[ cfg( test) ]
@@ -296,10 +369,24 @@ mod tests {
296
369
297
370
assert_eq ! ( bnb. inner. distribution. class_priors, & [ 0.75 , 0.25 ] ) ;
298
371
assert_eq ! (
299
- bnb. inner . distribution . feature_prob ,
372
+ bnb. feature_log_prob ( ) ,
300
373
& [
301
- & [ 0.4 , 0.8 , 0.2 , 0.4 , 0.4 , 0.2 ] ,
302
- & [ 1. / 3.0 , 2. / 3.0 , 2. / 3.0 , 1. / 3.0 , 1. / 3.0 , 2. / 3.0 ]
374
+ & [
375
+ -0.916290731874155 ,
376
+ -0.2231435513142097 ,
377
+ -1.6094379124341003 ,
378
+ -0.916290731874155 ,
379
+ -0.916290731874155 ,
380
+ -1.6094379124341003
381
+ ] ,
382
+ & [
383
+ -1.0986122886681098 ,
384
+ -0.40546510810816444 ,
385
+ -0.40546510810816444 ,
386
+ -1.0986122886681098 ,
387
+ -1.0986122886681098 ,
388
+ -0.40546510810816444
389
+ ]
303
390
]
304
391
) ;
305
392
@@ -335,13 +422,36 @@ mod tests {
335
422
336
423
let y_hat = bnb. predict ( & x) . unwrap ( ) ;
337
424
425
+ assert_eq ! ( bnb. classes( ) , & [ 0. , 1. , 2. ] ) ;
426
+ assert_eq ! ( bnb. class_count( ) , & [ 7 , 3 , 5 ] ) ;
427
+ assert_eq ! ( bnb. n_features( ) , 10 ) ;
428
+ assert_eq ! (
429
+ bnb. feature_count( ) ,
430
+ & [
431
+ & [ 5 , 6 , 6 , 7 , 6 , 4 , 6 , 7 , 7 , 7 ] ,
432
+ & [ 3 , 3 , 3 , 1 , 3 , 2 , 3 , 2 , 2 , 3 ] ,
433
+ & [ 4 , 4 , 3 , 4 , 5 , 2 , 4 , 5 , 3 , 4 ]
434
+ ]
435
+ ) ;
436
+
338
437
assert ! ( bnb
339
438
. inner
340
439
. distribution
341
440
. class_priors
342
441
. approximate_eq( & vec!( 0.46 , 0.2 , 0.33 ) , 1e-2 ) ) ;
343
- assert ! ( bnb. inner. distribution. feature_prob[ 1 ] . approximate_eq(
344
- & vec!( 0.8 , 0.8 , 0.8 , 0.4 , 0.8 , 0.6 , 0.8 , 0.6 , 0.6 , 0.8 ) ,
442
+ assert ! ( bnb. feature_log_prob( ) [ 1 ] . approximate_eq(
443
+ & vec![
444
+ -0.22314355 ,
445
+ -0.22314355 ,
446
+ -0.22314355 ,
447
+ -0.91629073 ,
448
+ -0.22314355 ,
449
+ -0.51082562 ,
450
+ -0.22314355 ,
451
+ -0.51082562 ,
452
+ -0.51082562 ,
453
+ -0.22314355
454
+ ] ,
345
455
1e-1
346
456
) ) ;
347
457
assert ! ( y_hat. approximate_eq(
0 commit comments