@@ -52,7 +52,8 @@ struct MultinomialNBDistribution<T: RealNumber> {
52
52
/// class labels known to the classifier
53
53
class_labels : Vec < T > ,
54
54
class_priors : Vec < T > ,
55
- feature_prob : Vec < Vec < T > > ,
55
+ /// Natural log of the smoothed empirical probability of each feature given a class, indexed as `[class_index][feature]`
56
+ feature_log_prob : Vec < Vec < T > > ,
56
57
}
57
58
58
59
impl < T : RealNumber , M : Matrix < T > > NBDistribution < T , M > for MultinomialNBDistribution < T > {
@@ -64,7 +65,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
64
65
let mut likelihood = T :: zero ( ) ;
65
66
for feature in 0 ..j. len ( ) {
66
67
let value = j. get ( feature) ;
67
- likelihood += value * self . feature_prob [ class_index] [ feature] . ln ( ) ;
68
+ likelihood += value * self . feature_log_prob [ class_index] [ feature] ;
68
69
}
69
70
likelihood
70
71
}
@@ -172,21 +173,23 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
172
173
}
173
174
}
174
175
175
- let feature_prob = feature_in_class_counter
176
+ let feature_log_prob = feature_in_class_counter
176
177
. iter ( )
177
178
. map ( |feature_count| {
178
179
let n_c = feature_count. sum ( ) ;
179
180
feature_count
180
181
. iter ( )
181
- . map ( |& count| ( count + alpha) / ( n_c + alpha * T :: from ( n_features) . unwrap ( ) ) )
182
+ . map ( |& count| {
183
+ ( ( count + alpha) / ( n_c + alpha * T :: from ( n_features) . unwrap ( ) ) ) . ln ( )
184
+ } )
182
185
. collect ( )
183
186
} )
184
187
. collect ( ) ;
185
188
186
189
Ok ( Self {
187
190
class_labels,
188
191
class_priors,
189
- feature_prob ,
192
+ feature_log_prob ,
190
193
} )
191
194
}
192
195
}
@@ -246,6 +249,12 @@ impl<T: RealNumber, M: Matrix<T>> MultinomialNB<T, M> {
246
249
pub fn classes ( & self ) -> & Vec < T > {
247
250
& self . inner . distribution . class_labels
248
251
}
252
+
253
+ /// Natural log of the smoothed empirical probability of each feature given a class, ln P(x_i|y).
254
+ /// Returns a 2d vector of shape (n_classes, n_features), indexed as `[class][feature]`.
255
+ pub fn feature_log_prob ( & self ) -> & Vec < Vec < T > > {
256
+ & self . inner . distribution . feature_log_prob
257
+ }
249
258
}
250
259
251
260
#[ cfg( test) ]
@@ -278,10 +287,24 @@ mod tests {
278
287
279
288
assert_eq ! ( mnb. inner. distribution. class_priors, & [ 0.75 , 0.25 ] ) ;
280
289
assert_eq ! (
281
- mnb. inner . distribution . feature_prob ,
290
+ mnb. feature_log_prob ( ) ,
282
291
& [
283
- & [ 1. / 7. , 3. / 7. , 1. / 14. , 1. / 7. , 1. / 7. , 1. / 14. ] ,
284
- & [ 1. / 9. , 2. / 9.0 , 2. / 9.0 , 1. / 9.0 , 1. / 9.0 , 2. / 9.0 ]
292
+ & [
293
+ ( 1_f64 / 7_f64 ) . ln( ) ,
294
+ ( 3_f64 / 7_f64 ) . ln( ) ,
295
+ ( 1_f64 / 14_f64 ) . ln( ) ,
296
+ ( 1_f64 / 7_f64 ) . ln( ) ,
297
+ ( 1_f64 / 7_f64 ) . ln( ) ,
298
+ ( 1_f64 / 14_f64 ) . ln( )
299
+ ] ,
300
+ & [
301
+ ( 1_f64 / 9_f64 ) . ln( ) ,
302
+ ( 2_f64 / 9_f64 ) . ln( ) ,
303
+ ( 2_f64 / 9_f64 ) . ln( ) ,
304
+ ( 1_f64 / 9_f64 ) . ln( ) ,
305
+ ( 1_f64 / 9_f64 ) . ln( ) ,
306
+ ( 2_f64 / 9_f64 ) . ln( )
307
+ ]
285
308
]
286
309
) ;
287
310
@@ -322,9 +345,20 @@ mod tests {
322
345
. distribution
323
346
. class_priors
324
347
. approximate_eq( & vec!( 0.46 , 0.2 , 0.33 ) , 1e-2 ) ) ;
325
- assert ! ( nb. inner. distribution. feature_prob[ 1 ] . approximate_eq(
326
- & vec!( 0.07 , 0.12 , 0.07 , 0.15 , 0.07 , 0.09 , 0.08 , 0.10 , 0.08 , 0.11 ) ,
327
- 1e-1
348
+ assert ! ( nb. feature_log_prob( ) [ 1 ] . approximate_eq(
349
+ & vec![
350
+ -2.00148 ,
351
+ -2.35815494 ,
352
+ -2.00148 ,
353
+ -2.69462718 ,
354
+ -2.22462355 ,
355
+ -2.91777073 ,
356
+ -2.10684052 ,
357
+ -2.51230562 ,
358
+ -2.69462718 ,
359
+ -2.00148
360
+ ] ,
361
+ 1e-5
328
362
) ) ;
329
363
assert ! ( y_hat. approximate_eq(
330
364
& vec!( 2.0 , 2.0 , 0.0 , 0.0 , 0.0 , 2.0 , 2.0 , 1.0 , 0.0 , 1.0 , 0.0 , 2.0 , 0.0 , 0.0 , 2.0 ) ,
0 commit comments