Skip to content

Commit 59e1420

Browse files
authored
Fix tests (#203)
* Fix tests * Add again removed line Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
1 parent 4cfa51e commit 59e1420

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

src/dataset/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ pub mod digits;
88
pub mod generator;
99
pub mod iris;
1010

11-
use crate::numbers::basenum::Number;
1211
#[cfg(not(target_arch = "wasm32"))]
13-
use crate::numbers::realnum::RealNumber;
12+
use crate::numbers::{basenum::Number, realnum::RealNumber};
1413
#[cfg(not(target_arch = "wasm32"))]
1514
use std::fs::File;
1615
use std::io;

src/linear/logistic_regression.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ mod tests {
715715
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
716716
#[test]
717717
fn lr_fit_predict() {
718-
let x = DenseMatrix::from_2d_array(&[
718+
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
719719
&[1., -5.],
720720
&[2., 5.],
721721
&[3., -2.],
@@ -739,8 +739,12 @@ mod tests {
739739
assert_eq!(lr.coefficients().shape(), (3, 2));
740740
assert_eq!(lr.intercept().shape(), (3, 1));
741741

742-
assert!((*lr.coefficients().get((0, 0)) - 0.0435f32).abs() < 1e-4);
743-
assert!((*lr.intercept().get((0, 0)) - 0.1250f32).abs() < 1e-4);
742+
assert!((*lr.coefficients().get((0, 0)) - 0.0435).abs() < 1e-4);
743+
assert!(
744+
(*lr.intercept().get((0, 0)) - 0.1250).abs() < 1e-4,
745+
"expected to be least than 1e-4, got {}",
746+
(*lr.intercept().get((0, 0)) - 0.1250).abs()
747+
);
744748

745749
let y_hat = lr.predict(&x).unwrap();
746750

src/naive_bayes/multinomial.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ mod tests {
488488
&distribution.class_priors,
489489
&vec!(0.4666666666666667, 0.2, 0.3333333333333333)
490490
);
491+
492+
// Due to float differences in WASM32,
493+
// we disable this test for that arch
494+
#[cfg(not(target_arch = "wasm32"))]
491495
assert_eq!(
492496
&nb.feature_log_prob()[1],
493497
&vec![

0 commit comments

Comments
 (0)