Skip to content

Fix Eig for column-major arrays with real elements #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 47 additions & 31 deletions lax/src/eig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ macro_rules! impl_eig_complex {
mut a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
// Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate.
// However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A
// LAPACK assumes a column-major input. A row-major input can
// be interpreted as the transpose of a column-major input. So,
// for row-major inputs, we we want to solve the following,
// given the column-major input `A`:
//
// A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
//
// So, in this case, the right eigenvectors are the conjugates
// of the left eigenvectors computed with `A`, and the
// eigenvalues are the eigenvalues computed with `A`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (b'V', b'N'),
Expand Down Expand Up @@ -118,8 +126,22 @@ macro_rules! impl_eig_real {
mut a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
// Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate.
// However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A
// LAPACK assumes a column-major input. A row-major input can
// be interpreted as the transpose of a column-major input. So,
// for row-major inputs, we we want to solve the following,
// given the column-major input `A`:
//
// A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
//
// So, in this case, the right eigenvectors are the conjugates
// of the left eigenvectors computed with `A`, and the
// eigenvalues are the eigenvalues computed with `A`.
//
// We could conjugate the eigenvalues instead of the
// eigenvectors, but we have to reconstruct the eigenvectors
// into new matrices anyway, and by not modifying the
// eigenvalues, we preserve the nice ordering specified by
// `sgeev`/`dgeev`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (b'V', b'N'),
Expand Down Expand Up @@ -211,40 +233,34 @@ macro_rules! impl_eig_real {
// - v(j) = VR(:,j) + i*VR(:,j+1)
// - v(j+1) = VR(:,j) - i*VR(:,j+1).
//
// ```
// j -> <----pair----> <----pair---->
// [ ... (real), (imag), (imag), (imag), (imag), ... ] : eigs
// ^ ^ ^ ^ ^
// false false true false true : is_conjugate_pair
// ```
// In the C-layout case, we need the conjugates of the left
// eigenvectors, so the signs should be reversed.

let n = n as usize;
let v = vr.or(vl).unwrap();
let mut eigvecs = unsafe { vec_uninit(n * n) };
let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate
for j in 0..n {
if eig_im[j] == 0.0 {
// j-th eigenvalue is real
for i in 0..n {
eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0);
let mut col = 0;
while col < n {
if eig_im[col] == 0. {
// The corresponding eigenvalue is real.
for row in 0..n {
let re = v[row + col * n];
eigvecs[row + col * n] = Self::complex(re, 0.);
}
col += 1;
} else {
// j-th eigenvalue is complex
// complex conjugated pair can be `j-1` or `j+1`
if is_conjugate_pair {
let j_pair = j - 1;
assert!(j_pair < n);
for i in 0..n {
eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]);
}
} else {
let j_pair = j + 1;
assert!(j_pair < n);
for i in 0..n {
eigvecs[i + j * n] =
Self::complex(v[i + j * n], -v[i + j_pair * n]);
// This is a complex conjugate pair.
assert!(col + 1 < n);
for row in 0..n {
let re = v[row + col * n];
let mut im = v[row + (col + 1) * n];
if jobvl == b'V' {
im = -im;
}
eigvecs[row + col * n] = Self::complex(re, im);
eigvecs[row + (col + 1) * n] = Self::complex(re, -im);
}
is_conjugate_pair = !is_conjugate_pair;
col += 2;
}
}

Expand Down
32 changes: 26 additions & 6 deletions ndarray-linalg/tests/eig.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
use ndarray::*;
use ndarray_linalg::*;

fn sorted_eigvals<T: Scalar>(eigvals: ArrayView1<'_, T>) -> Array1<T> {
let mut indices: Vec<usize> = (0..eigvals.len()).collect();
indices.sort_by(|&ind1, &ind2| {
let e1 = eigvals[ind1];
let e2 = eigvals[ind2];
e1.re()
.partial_cmp(&e2.re())
.unwrap()
.then(e1.im().partial_cmp(&e2.im()).unwrap())
});
indices.iter().map(|&ind| eigvals[ind]).collect()
}

// Test Av_i = e_i v_i for i = 0..n
fn test_eig<T: Scalar>(a: Array2<T>, eigs: Array1<T::Complex>, vecs: Array2<T::Complex>)
where
Expand Down Expand Up @@ -87,7 +100,10 @@ fn test_matrix_real<T: Scalar>() -> Array2<T::Real> {
}

fn test_matrix_real_t<T: Scalar>() -> Array2<T::Real> {
test_matrix_real::<T>().t().permuted_axes([1, 0]).to_owned()
let orig = test_matrix_real::<T>();
let mut out = Array2::zeros(orig.raw_dim().f());
out.assign(&orig);
out
}

fn answer_eig_real<T: Scalar>() -> Array1<T::Complex> {
Expand Down Expand Up @@ -154,10 +170,10 @@ fn test_matrix_complex<T: Scalar>() -> Array2<T::Complex> {
}

fn test_matrix_complex_t<T: Scalar>() -> Array2<T::Complex> {
test_matrix_complex::<T>()
.t()
.permuted_axes([1, 0])
.to_owned()
let orig = test_matrix_complex::<T>();
let mut out = Array2::zeros(orig.raw_dim().f());
out.assign(&orig);
out
}

fn answer_eig_complex<T: Scalar>() -> Array1<T::Complex> {
Expand Down Expand Up @@ -213,7 +229,11 @@ macro_rules! impl_test_real {
fn [<$real _eigvals_t>]() {
let a = test_matrix_real_t::<$real>();
let (e, _vecs) = a.eig().unwrap();
assert_close_l2!(&e, &answer_eig_real::<$real>(), 1.0e-3);
assert_close_l2!(
&sorted_eigvals(e.view()),
&sorted_eigvals(answer_eig_real::<$real>().view()),
1.0e-3
);
}

#[test]
Expand Down