Skip to content

Fix Solve::solve_h_* for complex inputs with standard layout #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions lax/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,49 @@ macro_rules! impl_solve {
ipiv: &Pivot,
b: &mut [Self],
) -> Result<()> {
let t = match l {
// If the array has C layout, then it needs to be handled
// specially, since LAPACK expects a Fortran-layout array.
// Reinterpreting a C layout array as Fortran layout is
// equivalent to transposing it. So, we can handle the "no
// transpose" and "transpose" cases by swapping to "transpose"
// or "no transpose", respectively. For the "Hermite" case, we
// can take advantage of the following:
//
// ```text
// A^H x = b
// ⟺ conj(A^T) x = b
// ⟺ conj(conj(A^T) x) = conj(b)
// ⟺ conj(conj(A^T)) conj(x) = conj(b)
// ⟺ A^T conj(x) = conj(b)
// ```
//
// So, we can handle this case by switching to "no transpose"
// (which is equivalent to transposing the array since it will
// be reinterpreted as Fortran layout) and applying the
// elementwise conjugate to `x` and `b`.
let (t, conj) = match l {
MatrixLayout::C { .. } => match t {
Transpose::No => Transpose::Transpose,
Transpose::Transpose | Transpose::Hermite => Transpose::No,
Transpose::No => (Transpose::Transpose, false),
Transpose::Transpose => (Transpose::No, false),
Transpose::Hermite => (Transpose::No, true),
},
_ => t,
MatrixLayout::F { .. } => (t, false),
};
let (n, _) = l.size();
let nrhs = 1;
let ldb = l.lda();
let mut info = 0;
if conj {
for b_elem in &mut *b {
*b_elem = b_elem.conj();
}
}
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
if conj {
for b_elem in &mut *b {
*b_elem = b_elem.conj();
}
}
info.as_lapack_result()?;
Ok(())
}
Expand Down
202 changes: 174 additions & 28 deletions ndarray-linalg/tests/solve.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,188 @@
use ndarray::*;
use ndarray_linalg::*;
use ndarray::prelude::*;
use ndarray_linalg::{
assert_aclose, assert_close_l2, c32, c64, random, random_hpd, solve::*, OperationNorm, Scalar,
};

// Checks a borrowing solver method (one that takes `&b`) against a known
// solution, once per `type => tolerance` pair.
//
// The matrix and the expected solution are first bound under the identifiers
// supplied by the caller, so the right-hand-side expression can refer to them
// by those names. The method is then exercised three ways: directly on the
// matrix, on the result of `factorize()`, and on the result of
// `factorize_into()` (last, because it takes the matrix by value).
macro_rules! test_solve {
    (
        [$($scalar:ty => $tol:expr),*],
        $mat_name:ident = $mat:expr,
        $sol_name:ident = $sol:expr,
        b = $rhs:expr,
        $method:ident,
    ) => {
        $({
            // Caller-visible names: `$rhs` is allowed to mention both of them.
            let $mat_name: Array2<$scalar> = $mat;
            let $sol_name: Array1<$scalar> = $sol;
            let rhs: Array1<$scalar> = $rhs;
            let (matrix, expected, tol) = ($mat_name, $sol_name, $tol);

            let direct = matrix.$method(&rhs).unwrap();
            assert_close_l2!(&direct, &expected, tol);

            let via_factorize = matrix.factorize().unwrap().$method(&rhs).unwrap();
            assert_close_l2!(&via_factorize, &expected, tol);

            let via_factorize_into = matrix.factorize_into().unwrap().$method(&rhs).unwrap();
            assert_close_l2!(&via_factorize_into, &expected, tol);
        })*
    };
}

// Checks a consuming solver method (one that takes `b` by value) against a
// known solution, once per `type => tolerance` pair.
//
// The matrix and the expected solution are bound under the identifiers
// supplied by the caller so the right-hand-side expression can refer to them.
// Each call receives its own clone of the right-hand side, because the method
// takes ownership of its argument. The `factorize_into()` path runs last
// since it takes the matrix by value.
macro_rules! test_solve_into {
    (
        [$($scalar:ty => $tol:expr),*],
        $mat_name:ident = $mat:expr,
        $sol_name:ident = $sol:expr,
        b = $rhs:expr,
        $method:ident,
    ) => {
        $({
            // Caller-visible names: `$rhs` is allowed to mention both of them.
            let $mat_name: Array2<$scalar> = $mat;
            let $sol_name: Array1<$scalar> = $sol;
            let rhs: Array1<$scalar> = $rhs;
            let (matrix, expected, tol) = ($mat_name, $sol_name, $tol);

            let direct = matrix.$method(rhs.clone()).unwrap();
            assert_close_l2!(&direct, &expected, tol);

            let via_factorize = matrix.factorize().unwrap().$method(rhs.clone()).unwrap();
            assert_close_l2!(&via_factorize, &expected, tol);

            let via_factorize_into = matrix
                .factorize_into()
                .unwrap()
                .$method(rhs.clone())
                .unwrap();
            assert_close_l2!(&via_factorize_into, &expected, tol);
        })*
    };
}

// Checks an in-place solver method (one that takes `&mut b`) against a known
// solution, once per `type => tolerance` pair.
//
// For each of the three call paths (directly on the matrix, via
// `factorize()`, via `factorize_into()`) a fresh scratch copy of the
// right-hand side is made, and two things are asserted: the value returned by
// the method matches the expected solution, and so does the scratch buffer
// itself after the call.
macro_rules! test_solve_inplace {
    (
        [$($scalar:ty => $tol:expr),*],
        $mat_name:ident = $mat:expr,
        $sol_name:ident = $sol:expr,
        b = $rhs:expr,
        $method:ident,
    ) => {
        $({
            // Caller-visible names: `$rhs` is allowed to mention both of them.
            let $mat_name: Array2<$scalar> = $mat;
            let $sol_name: Array1<$scalar> = $sol;
            let rhs: Array1<$scalar> = $rhs;
            let (matrix, expected, tol) = ($mat_name, $sol_name, $tol);

            // Path 1: solve directly on the matrix.
            {
                let mut scratch = rhs.clone();
                assert_close_l2!(&matrix.$method(&mut scratch).unwrap(), &expected, tol);
                assert_close_l2!(&scratch, &expected, tol);
            }

            // Path 2: solve through `factorize()`.
            {
                let mut scratch = rhs.clone();
                assert_close_l2!(
                    &matrix.factorize().unwrap().$method(&mut scratch).unwrap(),
                    &expected,
                    tol
                );
                assert_close_l2!(&scratch, &expected, tol);
            }

            // Path 3: solve through `factorize_into()`; last, since it takes
            // the matrix by value.
            {
                let mut scratch = rhs.clone();
                assert_close_l2!(
                    &matrix.factorize_into().unwrap().$method(&mut scratch).unwrap(),
                    &expected,
                    tol
                );
                assert_close_l2!(&scratch, &expected, tol);
            }
        })*
    };
}

// Runs the same scenario through all three solver flavours.
//
// The final bracketed list names, in order, the borrowing method, the
// consuming method and the in-place method. Every other argument is forwarded
// unchanged to `test_solve!`, `test_solve_into!` and `test_solve_inplace!`
// respectively.
macro_rules! test_solve_all {
    (
        [$($scalar:ty => $tol:expr),*],
        $mat_name:ident = $mat:expr,
        $sol_name:ident = $sol:expr,
        b = $rhs:expr,
        [$by_ref:ident, $by_value:ident, $in_place:ident],
    ) => {
        test_solve!(
            [$($scalar => $tol),*],
            $mat_name = $mat,
            $sol_name = $sol,
            b = $rhs,
            $by_ref,
        );
        test_solve_into!(
            [$($scalar => $tol),*],
            $mat_name = $mat,
            $sol_name = $sol,
            b = $rhs,
            $by_value,
        );
        test_solve_inplace!(
            [$($scalar => $tol),*],
            $mat_name = $mat,
            $sol_name = $sol,
            b = $rhs,
            $in_place,
        );
    };
}

// Solves `A x = b` for random real square matrices of size 0 through 8, built
// with `set_f(false)` and `set_f(true)`, using `solve`, `solve_into` and
// `solve_inplace`.
#[test]
fn solve_random_float() {
    for dim in 0..9 {
        for fortran in [false, true].iter().copied() {
            test_solve_all!(
                [f32 => 1e-3, f64 => 1e-9],
                mat = random([dim; 2].set_f(fortran)),
                sol = random(dim),
                b = mat.dot(&sol),
                [solve, solve_into, solve_inplace],
            );
        }
    }
}

// Solves `A x = b` for random complex square matrices of size 0 through 8,
// built with `set_f(false)` and `set_f(true)`, using `solve`, `solve_into`
// and `solve_inplace`.
#[test]
fn solve_random_complex() {
    for dim in 0..9 {
        for fortran in [false, true].iter().copied() {
            test_solve_all!(
                [c32 => 1e-3, c64 => 1e-9],
                mat = random([dim; 2].set_f(fortran)),
                sol = random(dim),
                b = mat.dot(&sol),
                [solve, solve_into, solve_inplace],
            );
        }
    }
}

#[test]
fn solve_random() {
let a: Array2<f64> = random((3, 3));
let x: Array1<f64> = random(3);
let b = a.dot(&x);
let y = a.solve_into(b).unwrap();
assert_close_l2!(&x, &y, 1e-7);
fn solve_t_random_float() {
for n in 0..=8 {
for &set_f in &[false, true] {
test_solve_all!(
[f32 => 1e-3, f64 => 1e-9],
a = random([n; 2].set_f(set_f)),
x = random(n),
b = a.t().dot(&x),
[solve_t, solve_t_into, solve_t_inplace],
);
}
}
}

#[test]
fn solve_random_t() {
let a: Array2<f64> = random((3, 3).f());
let x: Array1<f64> = random(3);
let b = a.dot(&x);
let y = a.solve_into(b).unwrap();
assert_close_l2!(&x, &y, 1e-7);
fn solve_t_random_complex() {
for n in 0..=8 {
for &set_f in &[false, true] {
test_solve_all!(
[c32 => 1e-3, c64 => 1e-9],
a = random([n; 2].set_f(set_f)),
x = random(n),
b = a.t().dot(&x),
[solve_t, solve_t_into, solve_t_inplace],
);
}
}
}

#[test]
fn solve_factorized() {
let a: Array2<f64> = random((3, 3));
let ans: Array1<f64> = random(3);
let b = a.dot(&ans);
let f = a.factorize_into().unwrap();
let x = f.solve_into(b).unwrap();
assert_close_l2!(&x, &ans, 1e-7);
fn solve_h_random_float() {
for n in 0..=8 {
for &set_f in &[false, true] {
test_solve_all!(
[f32 => 1e-3, f64 => 1e-9],
a = random([n; 2].set_f(set_f)),
x = random(n),
b = a.t().mapv(|x| x.conj()).dot(&x),
[solve_h, solve_h_into, solve_h_inplace],
);
}
}
}

#[test]
fn solve_factorized_t() {
let a: Array2<f64> = random((3, 3).f());
let ans: Array1<f64> = random(3);
let b = a.dot(&ans);
let f = a.factorize_into().unwrap();
let x = f.solve_into(b).unwrap();
assert_close_l2!(&x, &ans, 1e-7);
fn solve_h_random_complex() {
for n in 0..=8 {
for &set_f in &[false, true] {
test_solve_all!(
[c32 => 1e-3, c64 => 1e-9],
a = random([n; 2].set_f(set_f)),
x = random(n),
b = a.t().mapv(|x| x.conj()).dot(&x),
[solve_h, solve_h_into, solve_h_inplace],
);
}
}
}

#[test]
Expand Down