Skip to content

Commit 06fda75

Browse files
committed
Add feature_log_prob getter to MultinomialNB
1 parent b85bea8 commit 06fda75

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

src/naive_bayes/multinomial.rs

Lines changed: 45 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,8 @@ struct MultinomialNBDistribution<T: RealNumber> {
5252
/// class labels known to the classifier
5353
class_labels: Vec<T>,
5454
class_priors: Vec<T>,
55-
feature_prob: Vec<Vec<T>>,
55+
/// Empirical log probability of features given a class
56+
feature_log_prob: Vec<Vec<T>>,
5657
}
5758

5859
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> {
@@ -64,7 +65,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
6465
let mut likelihood = T::zero();
6566
for feature in 0..j.len() {
6667
let value = j.get(feature);
67-
likelihood += value * self.feature_prob[class_index][feature].ln();
68+
likelihood += value * self.feature_log_prob[class_index][feature];
6869
}
6970
likelihood
7071
}
@@ -172,21 +173,23 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
172173
}
173174
}
174175

175-
let feature_prob = feature_in_class_counter
176+
let feature_log_prob = feature_in_class_counter
176177
.iter()
177178
.map(|feature_count| {
178179
let n_c = feature_count.sum();
179180
feature_count
180181
.iter()
181-
.map(|&count| (count + alpha) / (n_c + alpha * T::from(n_features).unwrap()))
182+
.map(|&count| {
183+
((count + alpha) / (n_c + alpha * T::from(n_features).unwrap())).ln()
184+
})
182185
.collect()
183186
})
184187
.collect();
185188

186189
Ok(Self {
187190
class_labels,
188191
class_priors,
189-
feature_prob,
192+
feature_log_prob,
190193
})
191194
}
192195
}
@@ -246,6 +249,12 @@ impl<T: RealNumber, M: Matrix<T>> MultinomialNB<T, M> {
246249
pub fn classes(&self) -> &Vec<T> {
247250
&self.inner.distribution.class_labels
248251
}
252+
253+
/// Empirical log probability of features given a class, log P(x_i|y).
254+
/// Returns a reference to a 2-D vector of shape (n_classes, n_features).
255+
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
256+
&self.inner.distribution.feature_log_prob
257+
}
249258
}
250259

251260
#[cfg(test)]
@@ -278,10 +287,24 @@ mod tests {
278287

279288
assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]);
280289
assert_eq!(
281-
mnb.inner.distribution.feature_prob,
290+
mnb.feature_log_prob(),
282291
&[
283-
&[1. / 7., 3. / 7., 1. / 14., 1. / 7., 1. / 7., 1. / 14.],
284-
&[1. / 9., 2. / 9.0, 2. / 9.0, 1. / 9.0, 1. / 9.0, 2. / 9.0]
292+
&[
293+
(1_f64 / 7_f64).ln(),
294+
(3_f64 / 7_f64).ln(),
295+
(1_f64 / 14_f64).ln(),
296+
(1_f64 / 7_f64).ln(),
297+
(1_f64 / 7_f64).ln(),
298+
(1_f64 / 14_f64).ln()
299+
],
300+
&[
301+
(1_f64 / 9_f64).ln(),
302+
(2_f64 / 9_f64).ln(),
303+
(2_f64 / 9_f64).ln(),
304+
(1_f64 / 9_f64).ln(),
305+
(1_f64 / 9_f64).ln(),
306+
(2_f64 / 9_f64).ln()
307+
]
285308
]
286309
);
287310

@@ -322,9 +345,20 @@ mod tests {
322345
.distribution
323346
.class_priors
324347
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
325-
assert!(nb.inner.distribution.feature_prob[1].approximate_eq(
326-
&vec!(0.07, 0.12, 0.07, 0.15, 0.07, 0.09, 0.08, 0.10, 0.08, 0.11),
327-
1e-1
348+
assert!(nb.feature_log_prob()[1].approximate_eq(
349+
&vec![
350+
-2.00148,
351+
-2.35815494,
352+
-2.00148,
353+
-2.69462718,
354+
-2.22462355,
355+
-2.91777073,
356+
-2.10684052,
357+
-2.51230562,
358+
-2.69462718,
359+
-2.00148
360+
],
361+
1e-5
328362
));
329363
assert!(y_hat.approximate_eq(
330364
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0),

0 commit comments

Comments
 (0)