Skip to content

Commit aa4c144

Browse files
committed
Add class_count to GaussianNB, BernoulliNB
1 parent 06fda75 commit aa4c144

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

src/naive_bayes/bernoulli.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ use serde::{Deserialize, Serialize};
5151
struct BernoulliNBDistribution<T: RealNumber> {
5252
/// class labels known to the classifier
5353
class_labels: Vec<T>,
54+
/// number of training samples observed in each class
55+
class_count: Vec<usize>,
56+
/// probability of each class
5457
class_priors: Vec<T>,
58+
/// probability of features per class
5559
feature_prob: Vec<Vec<T>>,
5660
}
5761

@@ -157,10 +161,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
157161
let y = y.to_vec();
158162

159163
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
160-
let mut class_count = vec![T::zero(); class_labels.len()];
164+
let mut class_count = vec![0_usize; class_labels.len()];
161165

162166
for class_index in indices.iter() {
163-
class_count[*class_index] += T::one();
167+
class_count[*class_index] += 1;
164168
}
165169

166170
let class_priors = if let Some(class_priors) = priors {
@@ -173,7 +177,7 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
173177
} else {
174178
class_count
175179
.iter()
176-
.map(|&c| c / T::from(n_samples).unwrap())
180+
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
177181
.collect()
178182
};
179183

@@ -191,14 +195,18 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
191195
.map(|(class_index, feature_count)| {
192196
feature_count
193197
.iter()
194-
.map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two()))
198+
.map(|&count| {
199+
(count + alpha)
200+
/ (T::from(class_count[class_index]).unwrap() + alpha * T::two())
201+
})
195202
.collect()
196203
})
197204
.collect();
198205

199206
Ok(Self {
200207
class_labels,
201208
class_priors,
209+
class_count,
202210
feature_prob,
203211
})
204212
}
@@ -272,6 +280,12 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
272280
pub fn classes(&self) -> &Vec<T> {
273281
&self.inner.distribution.class_labels
274282
}
283+
284+
/// Number of training samples observed in each class.
285+
/// Returns a vector of size n_classes.
286+
pub fn class_count(&self) -> &Vec<usize> {
287+
&self.inner.distribution.class_count
288+
}
275289
}
276290

277291
#[cfg(test)]
@@ -342,6 +356,7 @@ mod tests {
342356
let y_hat = bnb.predict(&x).unwrap();
343357

344358
assert_eq!(bnb.classes(), &[0., 1., 2.]);
359+
assert_eq!(bnb.class_count(), &[7, 3, 5]);
345360

346361
assert!(bnb
347362
.inner

src/naive_bayes/gaussian.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ use serde::{Deserialize, Serialize};
3939
struct GaussianNBDistribution<T: RealNumber> {
4040
/// class labels known to the classifier
4141
class_labels: Vec<T>,
42+
/// number of training samples observed in each class
43+
class_count: Vec<usize>,
4244
/// probability of each class.
4345
class_priors: Vec<T>,
4446
/// variance of each feature per class
@@ -117,12 +119,12 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
117119
let y = y.to_vec();
118120
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
119121

120-
let mut class_count = vec![T::zero(); class_labels.len()];
122+
let mut class_count = vec![0_usize; class_labels.len()];
121123

122124
let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()];
123125

124126
for (row, class_index) in row_iter(x).zip(indices.iter()) {
125-
class_count[*class_index] += T::one();
127+
class_count[*class_index] += 1;
126128
subdataset[*class_index].push(row);
127129
}
128130

@@ -135,8 +137,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
135137
class_priors
136138
} else {
137139
class_count
138-
.into_iter()
139-
.map(|c| c / T::from(n_samples).unwrap())
140+
.iter()
141+
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
140142
.collect()
141143
};
142144

@@ -160,6 +162,7 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
160162

161163
Ok(Self {
162164
class_labels,
165+
class_count,
163166
class_priors,
164167
var,
165168
theta,
@@ -226,6 +229,12 @@ impl<T: RealNumber, M: Matrix<T>> GaussianNB<T, M> {
226229
&self.inner.distribution.class_labels
227230
}
228231

232+
/// Number of training samples observed in each class.
233+
/// Returns a vector of size n_classes.
234+
pub fn class_count(&self) -> &Vec<usize> {
235+
&self.inner.distribution.class_count
236+
}
237+
229238
/// Probability of each class
230239
/// Returns a vector of size n_classes.
231240
pub fn class_priors(&self) -> &Vec<T> {
@@ -268,6 +277,8 @@ mod tests {
268277

269278
assert_eq!(gnb.classes(), &[1., 2.]);
270279

280+
assert_eq!(gnb.class_count(), &[3, 3]);
281+
271282
assert_eq!(
272283
gnb.var(),
273284
&[

0 commit comments

Comments
 (0)