Skip to content

Revise tests for least-square problems #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lax/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ macro_rules! impl_least_squares {
}
let k = ::std::cmp::min(m, n);
let nrhs = 1;
let ldb = match a_layout {
MatrixLayout::F { .. } => m.max(n),
MatrixLayout::C { .. } => 1,
};
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;
Expand All @@ -54,9 +58,7 @@ macro_rules! impl_least_squares {
a,
a_layout.lda(),
b,
// this is the 'leading dimension of b', in the case where
// b is a single vector, this is 1
nrhs,
ldb,
&mut singular_values,
rcond,
&mut rank,
Expand Down
212 changes: 21 additions & 191 deletions ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
//! // `a` and `b` have been moved, no longer valid
//! ```

use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2};
use ndarray::*;

use crate::error::*;
use crate::lapack::least_squares::*;
Expand Down Expand Up @@ -352,7 +352,10 @@ where
// we need a new rhs b/c it will be overwritten with the solution
// for which we need `n` entries
let k = rhs.shape()[1];
let mut new_rhs = Array2::<E>::zeros((n, k));
let mut new_rhs = match self.layout()? {
MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
};
new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
compute_least_squares_nrhs(self, &mut new_rhs)
} else {
Expand Down Expand Up @@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(

#[cfg(test)]
mod tests {
use super::*;
use crate::{error::LinalgError, *};
use approx::AbsDiffEq;
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
use num_complex::Complex;

//
// Test cases taken from the scipy test suite for the scipy lstsq function
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
//
#[test]
fn scipy_test_simple_exact() {
let a = array![[1., 20.], [-30., 4.]];
let bs = vec![
array![[1., 0.], [0., 1.]],
array![[1.], [0.]],
array![[2., 1.], [-30., 4.]],
];
for b in &bs {
let res = a.least_squares(b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(res
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&rssq, 1e-12));
assert!(b_hat.abs_diff_eq(&b, 1e-12));
}
}

#[test]
fn scipy_test_simple_overdetermined() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
}

#[test]
fn scipy_test_simple_overdetermined_f32() {
let a: Array2<f32> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f32> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
}

fn c(re: f64, im: f64) -> Complex<f64> {
Complex::new(re, im)
}

#[test]
fn scipy_test_simple_overdetermined_complex() {
let a: Array2<c64> = array![
[c(1., 2.), c(2., 0.)],
[c(4., 0.), c(5., 0.)],
[c(3., 0.), c(4., 0.)]
];
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res.solution.abs_diff_eq(
&array![
c(-0.4831460674157303, 0.258426966292135),
c(0.921348314606741, 0.292134831460674)
],
1e-12
));
}

#[test]
fn scipy_test_simple_underdetermined() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}

/// This test case tests the underdetermined case for multiple right hand
/// sides. Adapted from scipy lstsq tests.
#[test]
fn scipy_test_simple_underdetermined_nrhs() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![
[-0.055555555555555, -0.055555555555555],
[0.111111111111111, 0.111111111111111],
[0.277777777777777, 0.277777777777777]
];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}
use ndarray::*;

//
// Test that the different lest squares traits work as intended on the
Expand Down Expand Up @@ -554,23 +449,23 @@ mod tests {
}

#[test]
fn test_least_squares_on_arc() {
fn on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_cow() {
fn on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_view() {
fn on_view() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let av = a.view();
Expand All @@ -580,7 +475,7 @@ mod tests {
}

#[test]
fn test_least_squares_on_view_mut() {
fn on_view_mut() {
let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let mut b: Array1<f64> = array![1., 2., 3.];
let av = a.view_mut();
Expand All @@ -590,7 +485,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_owned() {
fn into_on_owned() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let ac = a.clone();
Expand All @@ -600,7 +495,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_arc() {
fn into_on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let a2 = a.clone();
Expand All @@ -610,7 +505,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_cow() {
fn into_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let a2 = a.clone();
Expand All @@ -620,7 +515,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_owned() {
fn in_place_on_owned() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
Expand All @@ -630,7 +525,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_cow() {
fn in_place_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let mut a2 = a.clone();
Expand All @@ -640,7 +535,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_mut_view() {
fn in_place_on_mut_view() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
Expand All @@ -651,95 +546,30 @@ mod tests {
assert_result(&a, &b, &res);
}

//
// Test cases taken from the netlib documentation at
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
//
#[test]
fn netlib_lapack_example_for_dgels_1() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
let expected: Array1<f64> = array![2., 1., 1.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_2() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
let expected: Array1<f64> = array![1., 1., 2.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_nrhs() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = &b - &a.dot(&result.solution);
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(result
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&residual_ssq, 1e-12));
}

//
// Testing error cases
//
use crate::layout::MatrixLayout;

#[test]
fn test_incompatible_shape_error_on_mismatching_num_rows() {
fn incompatible_shape_error_on_mismatching_num_rows() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}

#[test]
fn test_incompatible_shape_error_on_mismatching_layout() {
fn incompatible_shape_error_on_mismatching_layout() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![[1.], [2.]].t().to_owned();
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });

let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}
Expand Down
Loading