diff --git a/src/emd_classification.rs b/src/emd_classification.rs index c047276..67f8379 100644 --- a/src/emd_classification.rs +++ b/src/emd_classification.rs @@ -66,7 +66,7 @@ pub fn emd_dist_serial( fn compute_emd_bulk( x: ArrayView2<'_, f64>, y: ArrayView3<'_, f64>, -) -> Result>, MatrixFormatError> { +) -> Array1> { let mut c = Array1::>::zeros(y.shape()[0]); Zip::from(&mut c) .and(y.axis_iter(Axis(0))) @@ -83,23 +83,23 @@ pub fn classify_closest_n( x: ArrayView2<'_, f64>, y: ArrayView3<'_, f64>, n: usize, -) -> Result, MatrixFormatError> { +) -> Array1 { let c = compute_emd_bulk(x, y); - let res = argsort(&c?.to_vec()); + let res = argsort(&c.to_vec()); assert!(n < res.len()); - unsafe { Ok(Array::from_vec(res.get_unchecked(0..n).to_vec())) } + unsafe { Array::from_vec(res.get_unchecked(0..n).to_vec()) } } pub fn classify_closest_n_bulk( x: ArrayView3<'_, f64>, y: ArrayView3<'_, f64>, n: usize, -) -> Result, MatrixFormatError> { +) -> Array2 { let mut c = Array2::::zeros((x.shape()[0], n)); Zip::from(c.rows_mut()) .and(x.axis_iter(Axis(0))) - .par_for_each(|mut c, mat_x| c += &classify_closest_n(mat_x, y, n).unwrap()); - Ok(c) + .par_for_each(|mut c, mat_x| c += &classify_closest_n(mat_x, y, n)); + c } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index b3f6baf..549170b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,19 +66,11 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> { x: PyReadonlyArray2<'py, f64>, y: PyReadonlyArray3<'py, f64>, n: usize, - ) -> PyResult<&'py PyArray1> { + ) -> &'py PyArray1 { let x = x.as_array(); let y = y.as_array(); let z = emd_classification::classify_closest_n(x, y, n); - - let _z = match z { - Ok(z) => return Ok(z.into_pyarray(py)), - Err(_e) => { - return Err(exceptions::PyTypeError::new_err( - "Failed to compute EMD distance.", - )) - } - }; + z.into_pyarray(py) } #[pyfn(m)] @@ -87,18 +79,11 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> { x: PyReadonlyArray3<'py, f64>, y: PyReadonlyArray3<'py, f64>, n: usize, - ) -> PyResult<&'py PyArray2> { + ) -> &'py PyArray2 { let x = x.as_array(); let y = y.as_array(); let z = emd_classification::classify_closest_n_bulk(x, y, n); - let _z = match z { - Ok(z) => return Ok(z.into_pyarray(py)), - Err(_e) => { - return Err(exceptions::PyTypeError::new_err( - "Failed to compute EMD distance.", - )) - } - }; + z.into_pyarray(py) } #[pyfn(m)]