Skip to content

Commit 1b42f8a

Browse files
authored
feat: Add getters for naive bayes structs (#74)
* feat: Add getters for GaussianNB * Add classes getter to BernoulliNB Add classes getter to CategoricalNB Add classes getter to MultinomialNB * Add feature_log_prob getter to MultinomialNB * Add class_count to NB structs * Add n_features getter for NB * Add feature_count to MultinomialNB and BernoulliNB * Add n_categories to CategoricalNB * Implement feature_log_prob and category_count getter for CategoricalNB * Implement feature_log_prob for BernoulliNB
1 parent c0be45b commit 1b42f8a

File tree

4 files changed

+420
-77
lines changed

4 files changed

+420
-77
lines changed

src/naive_bayes/bernoulli.rs

Lines changed: 127 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,44 @@ use serde::{Deserialize, Serialize};
4747

4848
/// Naive Bayes classifier for Bearnoulli features
4949
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
50-
#[derive(Debug, PartialEq)]
50+
#[derive(Debug)]
5151
struct BernoulliNBDistribution<T: RealNumber> {
5252
/// class labels known to the classifier
5353
class_labels: Vec<T>,
54+
/// number of training samples observed in each class
55+
class_count: Vec<usize>,
56+
/// probability of each class
5457
class_priors: Vec<T>,
55-
feature_prob: Vec<Vec<T>>,
58+
/// Number of samples encountered for each (class, feature)
59+
feature_count: Vec<Vec<usize>>,
60+
/// probability of features per class
61+
feature_log_prob: Vec<Vec<T>>,
62+
/// Number of features of each sample
63+
n_features: usize,
64+
}
65+
66+
impl<T: RealNumber> PartialEq for BernoulliNBDistribution<T> {
67+
fn eq(&self, other: &Self) -> bool {
68+
if self.class_labels == other.class_labels
69+
&& self.class_count == other.class_count
70+
&& self.class_priors == other.class_priors
71+
&& self.feature_count == other.feature_count
72+
&& self.n_features == other.n_features
73+
{
74+
for (a, b) in self
75+
.feature_log_prob
76+
.iter()
77+
.zip(other.feature_log_prob.iter())
78+
{
79+
if !a.approximate_eq(b, T::epsilon()) {
80+
return false;
81+
}
82+
}
83+
true
84+
} else {
85+
false
86+
}
87+
}
5688
}
5789

5890
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
@@ -65,9 +97,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
6597
for feature in 0..j.len() {
6698
let value = j.get(feature);
6799
if value == T::one() {
68-
likelihood += self.feature_prob[class_index][feature].ln();
100+
likelihood += self.feature_log_prob[class_index][feature];
69101
} else {
70-
likelihood += (T::one() - self.feature_prob[class_index][feature]).ln();
102+
likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln();
71103
}
72104
}
73105
likelihood
@@ -157,10 +189,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
157189
let y = y.to_vec();
158190

159191
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
160-
let mut class_count = vec![T::zero(); class_labels.len()];
192+
let mut class_count = vec![0_usize; class_labels.len()];
161193

162194
for class_index in indices.iter() {
163-
class_count[*class_index] += T::one();
195+
class_count[*class_index] += 1;
164196
}
165197

166198
let class_priors = if let Some(class_priors) = priors {
@@ -173,33 +205,46 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
173205
} else {
174206
class_count
175207
.iter()
176-
.map(|&c| c / T::from(n_samples).unwrap())
208+
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
177209
.collect()
178210
};
179211

180-
let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()];
212+
let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
181213

182214
for (row, class_index) in row_iter(x).zip(indices) {
183215
for (idx, row_i) in row.iter().enumerate().take(n_features) {
184-
feature_in_class_counter[class_index][idx] += *row_i;
216+
feature_in_class_counter[class_index][idx] +=
217+
row_i.to_usize().ok_or_else(|| {
218+
Failed::fit(&format!(
219+
"Elements of the matrix should be 1.0 or 0.0 |found|=[{}]",
220+
row_i
221+
))
222+
})?;
185223
}
186224
}
187225

188-
let feature_prob = feature_in_class_counter
226+
let feature_log_prob = feature_in_class_counter
189227
.iter()
190228
.enumerate()
191229
.map(|(class_index, feature_count)| {
192230
feature_count
193231
.iter()
194-
.map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two()))
232+
.map(|&count| {
233+
((T::from(count).unwrap() + alpha)
234+
/ (T::from(class_count[class_index]).unwrap() + alpha * T::two()))
235+
.ln()
236+
})
195237
.collect()
196238
})
197239
.collect();
198240

199241
Ok(Self {
200242
class_labels,
201243
class_priors,
202-
feature_prob,
244+
class_count,
245+
feature_count: feature_in_class_counter,
246+
feature_log_prob,
247+
n_features,
203248
})
204249
}
205250
}
@@ -266,6 +311,34 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
266311
self.inner.predict(x)
267312
}
268313
}
314+
315+
/// Class labels known to the classifier.
316+
/// Returns a vector of size n_classes.
317+
pub fn classes(&self) -> &Vec<T> {
318+
&self.inner.distribution.class_labels
319+
}
320+
321+
/// Number of training samples observed in each class.
322+
/// Returns a vector of size n_classes.
323+
pub fn class_count(&self) -> &Vec<usize> {
324+
&self.inner.distribution.class_count
325+
}
326+
327+
/// Number of features of each sample
328+
pub fn n_features(&self) -> usize {
329+
self.inner.distribution.n_features
330+
}
331+
332+
/// Number of samples encountered for each (class, feature)
333+
/// Returns a 2d vector of shape (n_classes, n_features)
334+
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
335+
&self.inner.distribution.feature_count
336+
}
337+
338+
/// Empirical log probability of features given a class
339+
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
340+
&self.inner.distribution.feature_log_prob
341+
}
269342
}
270343

271344
#[cfg(test)]
@@ -296,10 +369,24 @@ mod tests {
296369

297370
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
298371
assert_eq!(
299-
bnb.inner.distribution.feature_prob,
372+
bnb.feature_log_prob(),
300373
&[
301-
&[0.4, 0.8, 0.2, 0.4, 0.4, 0.2],
302-
&[1. / 3.0, 2. / 3.0, 2. / 3.0, 1. / 3.0, 1. / 3.0, 2. / 3.0]
374+
&[
375+
-0.916290731874155,
376+
-0.2231435513142097,
377+
-1.6094379124341003,
378+
-0.916290731874155,
379+
-0.916290731874155,
380+
-1.6094379124341003
381+
],
382+
&[
383+
-1.0986122886681098,
384+
-0.40546510810816444,
385+
-0.40546510810816444,
386+
-1.0986122886681098,
387+
-1.0986122886681098,
388+
-0.40546510810816444
389+
]
303390
]
304391
);
305392

@@ -335,13 +422,36 @@ mod tests {
335422

336423
let y_hat = bnb.predict(&x).unwrap();
337424

425+
assert_eq!(bnb.classes(), &[0., 1., 2.]);
426+
assert_eq!(bnb.class_count(), &[7, 3, 5]);
427+
assert_eq!(bnb.n_features(), 10);
428+
assert_eq!(
429+
bnb.feature_count(),
430+
&[
431+
&[5, 6, 6, 7, 6, 4, 6, 7, 7, 7],
432+
&[3, 3, 3, 1, 3, 2, 3, 2, 2, 3],
433+
&[4, 4, 3, 4, 5, 2, 4, 5, 3, 4]
434+
]
435+
);
436+
338437
assert!(bnb
339438
.inner
340439
.distribution
341440
.class_priors
342441
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
343-
assert!(bnb.inner.distribution.feature_prob[1].approximate_eq(
344-
&vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8),
442+
assert!(bnb.feature_log_prob()[1].approximate_eq(
443+
&vec![
444+
-0.22314355,
445+
-0.22314355,
446+
-0.22314355,
447+
-0.91629073,
448+
-0.22314355,
449+
-0.51082562,
450+
-0.22314355,
451+
-0.51082562,
452+
-0.51082562,
453+
-0.22314355
454+
],
345455
1e-1
346456
));
347457
assert!(y_hat.approximate_eq(

0 commit comments

Comments
 (0)