-
Notifications
You must be signed in to change notification settings - Fork 86
feat: BernoulliNB #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/naive_bayes/bernoulli.rs
Outdated
} | ||
likelihood | ||
} else { | ||
T::zero() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added this in all the other implementations, but maybe this should be -infinity or NaN, right? @VolodymyrOrlov
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should panic here instead, since else
branch is unreachable, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will remove just this if and else branches since as you say this is unreachable (at least in the public api)
afa3622
to
ede3cfb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks good and numbers match Scikit's! Great clean code. I love it!
There are some minor changes I'd love to make but after your are done feel free to merge it in development
src/preprocessing/mod.rs
Outdated
use crate::linalg::Matrix; | ||
use crate::math::num::RealNumber; | ||
|
||
pub(crate) fn binarize<T: RealNumber, M: Matrix<T>>(x: &M, threshold: T) -> M { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please move this method to https://github.com/smartcorelib/smartcore/blob/development/src/linalg/stats.rs and doc string describing how it works? It would be awesome to have a short example in doc string as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added that in linals/stats.rs in a new Trait, let me know if that is what you were thinking about
src/naive_bayes/bernoulli.rs
Outdated
} | ||
likelihood | ||
} else { | ||
T::zero() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should panic here instead, since else
branch is unreachable, right?
]); | ||
let y = vec![0., 0., 0., 1.]; | ||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add this test please? It doesn't change anything but it's good to have a direct comparison with Scikit.
#[test]
fn bernoulli_nb_scikit_parity() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[2., 4., 0., 0., 2., 1., 2., 4., 2., 0.],
&[3., 4., 0., 2., 1., 0., 1., 4., 0., 3.],
&[1., 4., 2., 4., 1., 0., 1., 2., 3., 2.],
&[0., 3., 3., 4., 1., 0., 3., 1., 1., 1.],
&[0., 2., 1., 4., 3., 4., 1., 2., 3., 1.],
&[3., 2., 4., 1., 3., 0., 2., 4., 0., 2.],
&[3., 1., 3., 0., 2., 0., 4., 4., 3., 4.],
&[2., 2., 2., 0., 1., 1., 2., 1., 0., 1.],
&[3., 3., 2., 2., 0., 2., 3., 2., 2., 3.],
&[4., 3., 4., 4., 4., 2., 2., 0., 1., 4.],
&[3., 4., 2., 2., 1., 4., 4., 4., 1., 3.],
&[3., 0., 1., 4., 4., 0., 0., 3., 2., 4.],
&[2., 0., 3., 3., 1., 2., 0., 2., 4., 1.],
&[2., 4., 0., 4., 2., 4., 1., 3., 1., 4.],
&[0., 2., 2., 3., 4., 0., 4., 4., 4., 4.]]);
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
let y_hat = bnb.predict(&x).unwrap();
assert!(bnb.inner.distribution.class_priors.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
assert!(bnb.inner.distribution.feature_prob[1].approximate_eq(&vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8), 1e-1));
assert!(y_hat.approximate_eq(&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 1e-5));
}
@@ -0,0 +1,278 @@ | |||
use crate::error::Failed; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest to add a documentation here to describe the general objectives of the module and sources
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can help with documentation. I plan to cover most of the new functionality by a short summary, similar to this one right before releasing v0.2.0. What would be helpful to me are links to resources that you used to learn about these methods either in description of the PR, issue or in the code somewhere. While I can (and will) use links to official papers describing all NB methods, in most of the cases there are other resources in the web that do a better job of explaining an algorithm and I like to share such resources with our users.
ede3cfb
to
babb4a7
Compare
Implement BernoulliNB as stated in #14
References: