Skip to content

Commit

Permalink
Error handling (#2)
Browse files Browse the repository at this point in the history
* more error handling

* more error handling

* improved error handling

* improved error handling

* added global constant for BAD_VALUE

* modified BAD_VALUE to INF

* added some docs
  • Loading branch information
SpireGiorgioSavastano authored Aug 8, 2022
1 parent 152008a commit cc044d4
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 49 deletions.
76 changes: 38 additions & 38 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 21 additions & 6 deletions src/emd_classification.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use ndarray::prelude::*;
use ndarray::Zip;
use ordered_float::OrderedFloat;
use pathfinding::prelude::{kuhn_munkres_min, Matrix};
use pathfinding::prelude::{kuhn_munkres_min, Matrix, MatrixFormatError};

const BAD_VALUE: f64 = f64::INFINITY;

fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
let mut indices = (0..data.len()).collect::<Vec<_>>();
Expand All @@ -19,6 +21,8 @@ fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 {
.sqrt()
}

/// Compute the pairwise Euclidean distances between the rows of two 2-D data tensors (e.g., images).
///
pub fn euclidean_rdist_rust(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
let mut c = Array2::<f64>::zeros((x.nrows(), y.nrows()));
for i in 0..x.nrows() {
Expand All @@ -36,6 +40,8 @@ fn euclidean_rdist_row(x: &ArrayView1<'_, f64>, y: &ArrayView2<'_, f64>) -> Arra
z
}

/// Parallel computation of Euclidean distance between two 2-D data tensors (e.g., images)
///
pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
let mut c = Array2::<f64>::zeros((x.nrows(), y.nrows()));
Zip::from(x.rows())
Expand All @@ -44,20 +50,29 @@ pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Ar
c
}

// NOTE(review): this page renders the diff without +/- markers; the single-line signature below
// appears to be the pre-commit version (bare `OrderedFloat<f64>` return) that the multi-line
// signature further down replaces.
pub fn emd_dist_serial(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> OrderedFloat<f64> {
/// Compute the Earth Mover's Distance (EMD) between two 2-D data tensors (e.g., images).
///
/// The cost matrix is the output of `euclidean_rdist_rust(x, y)`; the returned value is the
/// total cost reported by `kuhn_munkres_min` for that matrix (the assignment itself is dropped).
///
/// # Errors
///
/// Returns a `MatrixFormatError` when `Matrix::from_vec` rejects the cost data.
pub fn emd_dist_serial(
x: ArrayView2<'_, f64>,
y: ArrayView2<'_, f64>,
) -> Result<OrderedFloat<f64>, MatrixFormatError> {
// Cost matrix; `euclidean_rdist_rust` allocates it with shape (x.nrows(), y.nrows()).
let c = euclidean_rdist_rust(x, y);
// Wrap every cost in `OrderedFloat`; presumably required because `kuhn_munkres_min` needs
// totally ordered weights -- confirm against the pathfinding documentation.
let costs = c.mapv(|elem| OrderedFloat::from(elem));
// Appears to be the pre-commit statement: `expect` panics if the conversion fails.
let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())
.expect("Failed to convert vec to Matrix");
// Replacement statement: `?` hands the `MatrixFormatError` back to the caller instead of panicking.
let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?;
// NOTE(review): pathfinding documents that `kuhn_munkres` panics when the matrix has more rows
// than columns; nothing here guards against x.nrows() > y.nrows() -- confirm callers ensure it.
let (emd_dist, _assignments) = kuhn_munkres_min(&weights);
// Pre-commit tail expression (bare value), followed by its `Result`-wrapped replacement.
emd_dist
Ok(emd_dist)
}

/// Compute the EMD between `x` and every 2-D slice of `y` taken along axis 0.
///
/// Returns one value per slice, i.e. an array of length `y.shape()[0]`. A slice for which
/// `emd_dist_serial` returns an error is reported as `BAD_VALUE` and the error is printed to
/// stdout; the function itself never returns an error.
fn compute_emd_bulk(x: ArrayView2<'_, f64>, y: ArrayView3<'_, f64>) -> Array1<OrderedFloat<f64>> {
// Output buffer, one slot per slice of `y` along axis 0.
let mut c = Array1::<OrderedFloat<f64>>::zeros(y.shape()[0]);
// Walk the output slots and the axis-0 slices of `y` in lockstep.
Zip::from(&mut c)
.and(y.axis_iter(Axis(0)))
// Appears to be the pre-commit closure (from when `emd_dist_serial` returned a bare value).
.for_each(|c, mat_y| *c = emd_dist_serial(mat_y, x));
// Replacement closure. Note the argument order: the slice of `y` is passed first and `x` second.
.for_each(|c, mat_y| {
// On failure, log the error and store the BAD_VALUE sentinel instead of propagating.
*c = emd_dist_serial(mat_y, x).unwrap_or_else(|err| {
println!("BAD_VALUE due to: {}", err);
return OrderedFloat::from(BAD_VALUE);
})
});
c
}

Expand Down
28 changes: 24 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
//! # Pillars
//!
//! `pillars` is a collection of algorithms implemented in Python and Rust.
//!
//! ## Highlights
//!
//! - Computation of the Earth Mover's Distance (EMD)
//!
use numpy::{
IntoPyArray, PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3,
};
use pyo3::{exceptions, pymodule, types::PyModule, PyResult, Python};

mod emd_classification;
mod netcdf_utils;
pub mod emd_classification;
pub mod netcdf_utils;

#[pymodule]
fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Expand Down Expand Up @@ -33,11 +42,22 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}

// Python-facing wrapper around `emd_classification::emd_dist_serial`, registered on the module
// through `#[pyfn(m)]`. Takes two read-only 2-D f64 NumPy arrays and yields the EMD as an `f64`;
// a failure in `emd_dist_serial` is turned into a `PyTypeError`.
#[pyfn(m)]
// NOTE(review): the single-line signature below appears to be the pre-commit version (bare `f64`
// return); the multi-line signature after it is the replacement returning `PyResult<f64>`.
fn compute_emd<'py>(x: PyReadonlyArray2<'py, f64>, y: PyReadonlyArray2<'py, f64>) -> f64 {
fn compute_emd<'py>(
x: PyReadonlyArray2<'py, f64>,
y: PyReadonlyArray2<'py, f64>,
) -> PyResult<f64> {
// View the NumPy inputs as ndarray views.
let x = x.as_array();
let y = y.as_array();
let z = emd_classification::emd_dist_serial(x, y);
// Pre-commit tail expression, from when `z` was a bare `OrderedFloat<f64>`.
f64::from(z)

// Both arms return from the function, so `_z` is never actually bound.
// NOTE(review): the underlying error `_e` is discarded; only the fixed message reaches Python.
let _z = match z {
Ok(z) => return Ok(f64::from(z)),
Err(_e) => {
return Err(exceptions::PyTypeError::new_err(
"Failed to compute EMD distance.",
))
}
};
}

#[pyfn(m)]
Expand Down
2 changes: 1 addition & 1 deletion src/netcdf_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub fn get_ddms_at_indices_ser(
path: &PathBuf,
variable_name: String,
indices: ArrayView1<usize>,
) -> Result<Array3<f64>, netcdf::error::Error> {
) -> netcdf::error::Result<Array3<f64>> {
let file = netcdf::open(path)?;

let var = &file.variable(&variable_name);
Expand Down

0 comments on commit cc044d4

Please sign in to comment.