Skip to content

Commit

Permalink
improved error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
SpireGiorgioSavastano authored and Giorgio Savastano committed Dec 5, 2022
1 parent 3ac67f9 commit df4cc53
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 26 deletions.
14 changes: 7 additions & 7 deletions src/emd_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub fn emd_dist_serial(
fn compute_emd_bulk(
x: ArrayView2<'_, f64>,
y: ArrayView3<'_, f64>,
) -> Result<Array1<OrderedFloat<f64>>, MatrixFormatError> {
) -> Array1<OrderedFloat<f64>> {
let mut c = Array1::<OrderedFloat<f64>>::zeros(y.shape()[0]);
Zip::from(&mut c)
.and(y.axis_iter(Axis(0)))
Expand All @@ -83,23 +83,23 @@ pub fn classify_closest_n(
x: ArrayView2<'_, f64>,
y: ArrayView3<'_, f64>,
n: usize,
) -> Result<Array1<usize>, MatrixFormatError> {
) -> Array1<usize> {
let c = compute_emd_bulk(x, y);
let res = argsort(&c?.to_vec());
let res = argsort(&c.to_vec());
assert!(n < res.len());
unsafe { Ok(Array::from_vec(res.get_unchecked(0..n).to_vec())) }
unsafe { Array::from_vec(res.get_unchecked(0..n).to_vec()) }
}

pub fn classify_closest_n_bulk(
x: ArrayView3<'_, f64>,
y: ArrayView3<'_, f64>,
n: usize,
) -> Result<Array2<usize>, MatrixFormatError> {
) -> Array2<usize> {
let mut c = Array2::<usize>::zeros((x.shape()[0], n));
Zip::from(c.rows_mut())
.and(x.axis_iter(Axis(0)))
.par_for_each(|mut c, mat_x| c += &classify_closest_n(mat_x, y, n).unwrap());
Ok(c)
.par_for_each(|mut c, mat_x| c += &classify_closest_n(mat_x, y, n));
c
}

#[cfg(test)]
Expand Down
23 changes: 4 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,11 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
x: PyReadonlyArray2<'py, f64>,
y: PyReadonlyArray3<'py, f64>,
n: usize,
) -> PyResult<&'py PyArray1<usize>> {
) -> &'py PyArray1<usize> {
let x = x.as_array();
let y = y.as_array();
let z = emd_classification::classify_closest_n(x, y, n);

let _z = match z {
Ok(z) => return Ok(z.into_pyarray(py)),
Err(_e) => {
return Err(exceptions::PyTypeError::new_err(
"Failed to compute EMD distance.",
))
}
};
z.into_pyarray(py)
}

#[pyfn(m)]
Expand All @@ -87,18 +79,11 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
x: PyReadonlyArray3<'py, f64>,
y: PyReadonlyArray3<'py, f64>,
n: usize,
) -> PyResult<&'py PyArray2<usize>> {
) -> &'py PyArray2<usize> {
let x = x.as_array();
let y = y.as_array();
let z = emd_classification::classify_closest_n_bulk(x, y, n);
let _z = match z {
Ok(z) => return Ok(z.into_pyarray(py)),
Err(_e) => {
return Err(exceptions::PyTypeError::new_err(
"Failed to compute EMD distance.",
))
}
};
z.into_pyarray(py)
}

#[pyfn(m)]
Expand Down

0 comments on commit df4cc53

Please sign in to comment.