@@ -51,7 +51,11 @@ use serde::{Deserialize, Serialize};
51
51
struct BernoulliNBDistribution < T : RealNumber > {
52
52
/// class labels known to the classifier
53
53
class_labels : Vec < T > ,
54
+ /// number of training samples observed in each class
55
+ class_count : Vec < usize > ,
56
+ /// probability of each class
54
57
class_priors : Vec < T > ,
58
+ /// probability of features per class
55
59
feature_prob : Vec < Vec < T > > ,
56
60
}
57
61
@@ -157,10 +161,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
157
161
let y = y. to_vec ( ) ;
158
162
159
163
let ( class_labels, indices) = <Vec < T > as RealNumberVector < T > >:: unique_with_indices ( & y) ;
160
- let mut class_count = vec ! [ T :: zero ( ) ; class_labels. len( ) ] ;
164
+ let mut class_count = vec ! [ 0_usize ; class_labels. len( ) ] ;
161
165
162
166
for class_index in indices. iter ( ) {
163
- class_count[ * class_index] += T :: one ( ) ;
167
+ class_count[ * class_index] += 1 ;
164
168
}
165
169
166
170
let class_priors = if let Some ( class_priors) = priors {
@@ -173,7 +177,7 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
173
177
} else {
174
178
class_count
175
179
. iter ( )
176
- . map ( |& c| c / T :: from ( n_samples) . unwrap ( ) )
180
+ . map ( |& c| T :: from ( c ) . unwrap ( ) / T :: from ( n_samples) . unwrap ( ) )
177
181
. collect ( )
178
182
} ;
179
183
@@ -191,14 +195,18 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
191
195
. map ( |( class_index, feature_count) | {
192
196
feature_count
193
197
. iter ( )
194
- . map ( |& count| ( count + alpha) / ( class_count[ class_index] + alpha * T :: two ( ) ) )
198
+ . map ( |& count| {
199
+ ( count + alpha)
200
+ / ( T :: from ( class_count[ class_index] ) . unwrap ( ) + alpha * T :: two ( ) )
201
+ } )
195
202
. collect ( )
196
203
} )
197
204
. collect ( ) ;
198
205
199
206
Ok ( Self {
200
207
class_labels,
201
208
class_priors,
209
+ class_count,
202
210
feature_prob,
203
211
} )
204
212
}
@@ -272,6 +280,12 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
272
280
pub fn classes(&self) -> &Vec<T> {
    // The labels are stored on the distribution owned by `self.inner`.
    let distribution = &self.inner.distribution;
    &distribution.class_labels
}
283
+
284
+ /// Number of training samples observed in each class.
285
+ /// Returns a vector of size n_classes.
286
+ pub fn class_count ( & self ) -> & Vec < usize > {
287
+ & self . inner . distribution . class_count
288
+ }
275
289
}
276
290
277
291
#[ cfg( test) ]
@@ -342,6 +356,7 @@ mod tests {
342
356
let y_hat = bnb. predict ( & x) . unwrap ( ) ;
343
357
344
358
assert_eq ! ( bnb. classes( ) , & [ 0. , 1. , 2. ] ) ;
359
+ assert_eq ! ( bnb. class_count( ) , & [ 7 , 3 , 5 ] ) ;
345
360
346
361
assert ! ( bnb
347
362
. inner
0 commit comments