Skip to content

Commit 64548dc

Browse files
committed
Add feature_log_prob getter for BernoulliNB
1 parent b02a1b0 commit 64548dc

File tree

1 file changed

+62
-12
lines changed

1 file changed

+62
-12
lines changed

src/naive_bayes/bernoulli.rs

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
//! ## References:
3434
//!
3535
//! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html)
36+
use std::ops::Not;
37+
3638
use crate::api::{Predictor, SupervisedEstimator};
3739
use crate::error::Failed;
3840
use crate::linalg::row_iter;
@@ -47,12 +49,26 @@ use serde::{Deserialize, Serialize};
4749

4850
/// Naive Bayes classifier for Bernoulli features
4951
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
50-
#[derive(Debug, PartialEq)]
52+
#[derive(Debug)]
5153
struct BernoulliNBDistribution<T: RealNumber> {
5254
/// class labels known to the classifier
5355
class_labels: Vec<T>,
5456
class_priors: Vec<T>,
55-
feature_prob: Vec<Vec<T>>,
57+
feature_log_prob: Vec<Vec<T>>,
58+
}
59+
60+
impl<T: RealNumber> PartialEq for BernoulliNBDistribution<T> {
61+
fn eq(&self, other: &Self) -> bool {
62+
if self.class_labels == other.class_labels && self.class_priors == other.class_priors {
63+
self.feature_log_prob
64+
.iter()
65+
.zip(other.feature_log_prob.iter())
66+
.any(|(left, right)| !left.approximate_eq(right, T::epsilon()))
67+
.not()
68+
} else {
69+
false
70+
}
71+
}
5672
}
5773

5874
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
@@ -65,9 +81,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
6581
for feature in 0..j.len() {
6682
let value = j.get(feature);
6783
if value == T::one() {
68-
likelihood += self.feature_prob[class_index][feature].ln();
84+
likelihood += self.feature_log_prob[class_index][feature];
6985
} else {
70-
likelihood += (T::one() - self.feature_prob[class_index][feature]).ln();
86+
likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln();
7187
}
7288
}
7389
likelihood
@@ -185,21 +201,23 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
185201
}
186202
}
187203

188-
let feature_prob = feature_in_class_counter
204+
let feature_log_prob = feature_in_class_counter
189205
.iter()
190206
.enumerate()
191207
.map(|(class_index, feature_count)| {
192208
feature_count
193209
.iter()
194-
.map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two()))
210+
.map(|&count| {
211+
((count + alpha) / (class_count[class_index] + alpha * T::two())).ln()
212+
})
195213
.collect()
196214
})
197215
.collect();
198216

199217
Ok(Self {
200218
class_labels,
201219
class_priors,
202-
feature_prob,
220+
feature_log_prob,
203221
})
204222
}
205223
}
@@ -272,6 +290,12 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
272290
pub fn classes(&self) -> &Vec<T> {
273291
&self.inner.distribution.class_labels
274292
}
293+
294+
/// Empirical log probability of features given a class, P(x_i|y).
295+
/// Returns a 2d vector of shape (n_classes, n_features)
296+
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
297+
&self.inner.distribution.feature_log_prob
298+
}
275299
}
276300

277301
#[cfg(test)]
@@ -302,10 +326,24 @@ mod tests {
302326

303327
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
304328
assert_eq!(
305-
bnb.inner.distribution.feature_prob,
329+
bnb.feature_log_prob(),
306330
&[
307-
&[0.4, 0.8, 0.2, 0.4, 0.4, 0.2],
308-
&[1. / 3.0, 2. / 3.0, 2. / 3.0, 1. / 3.0, 1. / 3.0, 2. / 3.0]
331+
&[
332+
-0.916290731874155,
333+
-0.2231435513142097,
334+
-1.6094379124341003,
335+
-0.916290731874155,
336+
-0.916290731874155,
337+
-1.6094379124341003
338+
],
339+
&[
340+
-1.0986122886681098,
341+
-0.40546510810816444,
342+
-0.40546510810816444,
343+
-1.0986122886681098,
344+
-1.0986122886681098,
345+
-0.40546510810816444
346+
]
309347
]
310348
);
311349

@@ -348,10 +386,22 @@ mod tests {
348386
.distribution
349387
.class_priors
350388
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
351-
assert!(bnb.inner.distribution.feature_prob[1].approximate_eq(
352-
&vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8),
389+
assert!(bnb.feature_log_prob()[1].approximate_eq(
390+
&vec![
391+
-0.22314355,
392+
-0.22314355,
393+
-0.22314355,
394+
-0.91629073,
395+
-0.22314355,
396+
-0.51082562,
397+
-0.22314355,
398+
-0.51082562,
399+
-0.51082562,
400+
-0.22314355
401+
],
353402
1e-1
354403
));
404+
println!("{:?}", y_hat);
355405
assert!(y_hat.approximate_eq(
356406
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
357407
1e-5

0 commit comments

Comments
 (0)