From cc044d436d0848e368de0e7f911f2b5c39330d0e Mon Sep 17 00:00:00 2001 From: Giorgio Savastano <49161489+SpireGiorgioSavastano@users.noreply.github.com> Date: Mon, 8 Aug 2022 10:57:02 +0200 Subject: [PATCH] Error handling (#2) * more error handling * more error handling * improved error handling * improved error handling * added global constant for BAD_VALUE * modified BAD_VALUE to INF * added some docs --- Cargo.lock | 76 +++++++++++++++++++-------------------- src/emd_classification.rs | 27 ++++++++++---- src/lib.rs | 28 ++++++++++++--- src/netcdf_utils.rs | 2 +- 4 files changed, 84 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e357f0f..7e87116 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,9 +46,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c02a4d71819009c192cf4872265391563fd6a84c81ff2c0f2a7026ca4c1d85c" +checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" dependencies = [ "cfg-if", "crossbeam-utils", @@ -56,9 +56,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" +checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" dependencies = [ "cfg-if", "crossbeam-epoch", @@ -67,9 +67,9 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07db9d94cbd326813772c968ccd25999e5f8ae22f4f8d1b11effa37ef6ce281d" +checksum = "045ebe27666471bb549370b4b0b3e51b07f56325befa4284db65fc89c02511b1" dependencies = [ "autocfg", "cfg-if", @@ -81,9 +81,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d82ee10ce34d7bc12c2122495e7593a9c41347ecdd64185af4ecf72cb1a7f83" +checksum = "51887d4adc7b564537b15adcfb307936f8075dfcd5f00dde9a9f1d29383682bc" dependencies = [ "cfg-if", "once_cell", @@ -91,9 +91,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f877be4f7c9f246b183111634f75baa039715e3f46ce860677d3b19a69fb229c" +checksum = "cdffe87e1d521a10f9696f833fe502293ea446d7f256c06128293a4119bdf4cb" dependencies = [ "quote", "syn", @@ -113,9 +113,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "ghost" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b93490550b1782c589a350f2211fff2e34682e25fed17ef53fc4fa8fe184975e" +checksum = "eb19fe8de3ea0920d282f7b77dd4227aea6b8b999b42cdf0ca41b2472b14443a" dependencies = [ "proc-macro2", "quote", @@ -176,9 +176,9 @@ dependencies = [ [[package]] name = "indoc" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a0bd019339e5d968b37855180087b7b9d512c5046fbd244cf8c95687927d6e" +checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3" [[package]] name = "integer-sqrt" @@ -216,9 +216,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.126" +version = "0.2.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +checksum = "505e71a4706fa491e9b1b55f51b95d4037d0821ee40131190475f692b35b009b" [[package]] name = "libloading" @@ -278,9 +278,9 @@ dependencies = [ [[package]] name = "ndarray" -version = "0.15.4" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec23e6762830658d2b3d385a75aa212af2f67a4586d4442907144f3bb6a1ca8" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" dependencies = [ "matrixmultiply", "num-complex", @@ -447,9 +447,9 @@ checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" [[package]] name = "proc-macro2" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd96a1e8ed2596c337f8eae5f24924ec83f5ad5ab21ea8e455d3566c69fbcaf7" +checksum = "0a2ca2c61bc9f3d74d2886294ab7b9853abd9c1ad903a3ac7815c58989bb7bab" dependencies = [ "unicode-ident", ] @@ -516,9 +516,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcdf212e9776fbcb2d23ab029360416bb1706b1aea2d1a5ba002727cbcab804" +checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" dependencies = [ "proc-macro2", ] @@ -555,9 +555,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.13" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ "bitflags", ] @@ -593,15 +593,15 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.141" +version = "1.0.142" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7af873f2c95b99fcb0bd0fe622a43e29514658873c8ceba88c4cb88833a22500" +checksum = "e590c437916fb6b221e1d00df6e3294f3fccd70ca7e92541c475d6ed6ef5fee2" [[package]] name = "serde_derive" -version = "1.0.141" +version = "1.0.142" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75743a150d003dd863b51dc809bcad0d73f2102c53632f1e954e738192a3413f" +checksum = "34b5b8d809babe02f538c2cfec6f2c1ed10804c0e5a6a041a049a4f5588ccc2e" dependencies = [ "proc-macro2", "quote", @@ -616,9 +616,9 @@ checksum = "2fd0db749597d91ff862fd1d55ea87f7855a744a8425a64695b6fca237d1dad1" [[package]] name = "syn" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c50aef8a904de4c23c788f104b7dddc7d6f79c647c7c8ce4cc8f73eb0ca773dd" +checksum = "58dbef6ec655055e20b86b15a8cc6d439cca19b667537ac6a1369572d151ab13" dependencies = [ "proc-macro2", "quote", @@ -633,18 +633,18 @@ checksum = "c02424087780c9b71cc96799eaeddff35af2bc513278cda5c99fc1f5d026d3c1" [[package]] name = "thiserror" -version = "1.0.31" +version = "1.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a" +checksum = "f5f6586b7f764adc0231f4c79be7b920e766bb2f3e51b3661cdb263828f19994" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.31" +version = "1.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a" +checksum = "12bafc5b54507e0149cdf1b145a5d80ab80a90bcd9275df43d4fff68460f6c21" dependencies = [ "proc-macro2", "quote", @@ -653,15 +653,15 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c61ba63f9235225a22310255a29b806b907c9b8c964bcbd0a2c70f3f2deea7" +checksum = "c4f5b37a154999a8f3f98cc23a628d850e154479cd94decf3414696e12e31aaf" [[package]] name = "unindent" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52fee519a3e570f7df377a06a1a7775cdbfb7aa460be7e08de2b1f0e69973a44" +checksum = "58ee9362deb4a96cef4d437d1ad49cffc9b9e92d202b6995674e928ce684f112" [[package]] name = "vcpkg" diff --git a/src/emd_classification.rs b/src/emd_classification.rs index 974d2b9..a96e59e 100644 --- a/src/emd_classification.rs +++ b/src/emd_classification.rs @@ -1,7 +1,9 @@ use ndarray::prelude::*; use ndarray::Zip; use ordered_float::OrderedFloat; -use pathfinding::prelude::{kuhn_munkres_min, Matrix}; +use pathfinding::prelude::{kuhn_munkres_min, Matrix, MatrixFormatError}; + +const BAD_VALUE: f64 = f64::INFINITY; fn argsort(data: &[T]) -> Vec { let mut indices = (0..data.len()).collect::>(); @@ -19,6 +21,8 @@ fn euclidean_distance(v1: &ArrayView1, v2: &ArrayView1) -> f64 { .sqrt() } +/// Compute Euclidean distance between two 2-D data tensors (e.g., images). +/// pub fn euclidean_rdist_rust(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2 { let mut c = Array2::::zeros((x.nrows(), y.nrows())); for i in 0..x.nrows() { @@ -36,6 +40,8 @@ fn euclidean_rdist_row(x: &ArrayView1<'_, f64>, y: &ArrayView2<'_, f64>) -> Arra z } +/// Parallel computation of Euclidean distance between two 2-D data tensors (e.g., images) +/// pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2 { let mut c = Array2::::zeros((x.nrows(), y.nrows())); Zip::from(x.rows()) @@ -44,20 +50,29 @@ pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Ar c } -pub fn emd_dist_serial(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> OrderedFloat { +/// Compute Earth Movers Distance (EMD) between two 2-D data tensors (e.g., images). +/// +pub fn emd_dist_serial( + x: ArrayView2<'_, f64>, + y: ArrayView2<'_, f64>, +) -> Result, MatrixFormatError> { let c = euclidean_rdist_rust(x, y); let costs = c.mapv(|elem| OrderedFloat::from(elem)); - let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec()) - .expect("Failed to convert vec to Matrix"); + let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?; let (emd_dist, _assignments) = kuhn_munkres_min(&weights); - emd_dist + Ok(emd_dist) } fn compute_emd_bulk(x: ArrayView2<'_, f64>, y: ArrayView3<'_, f64>) -> Array1> { let mut c = Array1::>::zeros(y.shape()[0]); Zip::from(&mut c) .and(y.axis_iter(Axis(0))) - .for_each(|c, mat_y| *c = emd_dist_serial(mat_y, x)); + .for_each(|c, mat_y| { + *c = emd_dist_serial(mat_y, x).unwrap_or_else(|err| { + println!("BAD_VALUE due to: {}", err); + return OrderedFloat::from(BAD_VALUE); + }) + }); c } diff --git a/src/lib.rs b/src/lib.rs index 6c43ee6..549170b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,19 @@ +//! # Pillars +//! +//! `pillars` is a collection of algorithms implemented in Python and Rust. +//! +//! ## Highlights +//! +//! - Computation of EMD distance +//! + use numpy::{ IntoPyArray, PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, }; use pyo3::{exceptions, pymodule, types::PyModule, PyResult, Python}; -mod emd_classification; -mod netcdf_utils; +pub mod emd_classification; +pub mod netcdf_utils; #[pymodule] fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -33,11 +42,22 @@ fn pillars(_py: Python<'_>, m: &PyModule) -> PyResult<()> { } #[pyfn(m)] - fn compute_emd<'py>(x: PyReadonlyArray2<'py, f64>, y: PyReadonlyArray2<'py, f64>) -> f64 { + fn compute_emd<'py>( + x: PyReadonlyArray2<'py, f64>, + y: PyReadonlyArray2<'py, f64>, + ) -> PyResult { let x = x.as_array(); let y = y.as_array(); let z = emd_classification::emd_dist_serial(x, y); - f64::from(z) + + let _z = match z { + Ok(z) => return Ok(f64::from(z)), + Err(_e) => { + return Err(exceptions::PyTypeError::new_err( + "Failed to compute EMD distance.", + )) + } + }; } #[pyfn(m)] diff --git a/src/netcdf_utils.rs b/src/netcdf_utils.rs index 39e14b6..700917b 100644 --- a/src/netcdf_utils.rs +++ b/src/netcdf_utils.rs @@ -5,7 +5,7 @@ pub fn get_ddms_at_indices_ser( path: &PathBuf, variable_name: String, indices: ArrayView1, -) -> Result, netcdf::error::Error> { +) -> netcdf::error::Result> { let file = netcdf::open(path)?; let var = &file.variable(&variable_name);