Skip to content

Commit c033cb9

Browse files
authored
Merge pull request #184 from bytesnake/master
Add LOBPCG solver for large symmetric positive definite eigenproblems
2 parents 8b55efc + ae2ce6a commit c033cb9

File tree

13 files changed

+1194
-19
lines changed

13 files changed

+1194
-19
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ optional = true
5252
paste = "0.1.9"
5353
criterion = "0.3.1"
5454

55+
[[bench]]
56+
name = "truncated_eig"
57+
harness = false
58+
5559
[[bench]]
5660
name = "eigh"
5761
harness = false
62+

benches/truncated_eig.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#[macro_use]
2+
extern crate criterion;
3+
4+
use criterion::Criterion;
5+
use ndarray::*;
6+
use ndarray_linalg::*;
7+
8+
macro_rules! impl_teig {
9+
($n:expr) => {
10+
paste::item! {
11+
fn [<teig $n>](c: &mut Criterion) {
12+
c.bench_function(&format!("truncated_eig{}", $n), |b| {
13+
let a: Array2<f64> = random(($n, $n));
14+
let a = a.t().dot(&a);
15+
16+
b.iter(move || {
17+
let _result = TruncatedEig::new(a.clone(), TruncatedOrder::Largest).decompose(1);
18+
})
19+
});
20+
c.bench_function(&format!("truncated_eig{}_t", $n), |b| {
21+
let a: Array2<f64> = random(($n, $n).f());
22+
let a = a.t().dot(&a);
23+
24+
b.iter(|| {
25+
let _result = TruncatedEig::new(a.clone(), TruncatedOrder::Largest).decompose(1);
26+
})
27+
});
28+
}
29+
}
30+
};
31+
}
32+
33+
impl_teig!(4);
34+
impl_teig!(8);
35+
impl_teig!(16);
36+
impl_teig!(32);
37+
impl_teig!(64);
38+
impl_teig!(128);
39+
40+
criterion_group!(teig, teig4, teig8, teig16, teig32, teig64, teig128);
41+
criterion_main!(teig);

examples/truncated_eig.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
extern crate ndarray;
2+
extern crate ndarray_linalg;
3+
4+
use ndarray::*;
5+
use ndarray_linalg::*;
6+
7+
fn main() {
8+
let n = 10;
9+
let v = random_unitary(n);
10+
11+
// set eigenvalues in decreasing order
12+
let t = Array1::linspace(n as f64, -(n as f64), n);
13+
14+
println!("Generate spectrum: {:?}", &t);
15+
16+
let t = Array2::from_diag(&t);
17+
let a = v.dot(&t.dot(&v.t()));
18+
19+
// calculate the truncated eigenproblem decomposition
20+
for (val, _) in TruncatedEig::new(a, TruncatedOrder::Largest) {
21+
println!("Found eigenvalue {}", val[0]);
22+
}
23+
}

examples/truncated_svd.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
extern crate ndarray;
2+
extern crate ndarray_linalg;
3+
4+
use ndarray::*;
5+
use ndarray_linalg::*;
6+
7+
fn main() {
8+
let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]);
9+
10+
// calculate the truncated singular value decomposition for 2 singular values
11+
let result = TruncatedSvd::new(a, TruncatedOrder::Largest).decompose(2).unwrap();
12+
13+
// acquire singular values, left-singular vectors and right-singular vectors
14+
let (u, sigma, v_t) = result.values_vectors();
15+
println!("Result of the singular value decomposition A = UΣV^T:");
16+
println!(" === U ===");
17+
println!("{:?}", u);
18+
println!(" === Σ ===");
19+
println!("{:?}", Array2::from_diag(&sigma));
20+
println!(" === V^T ===");
21+
println!("{:?}", v_t);
22+
}

src/eigh.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ where
4242
}
4343
}
4444

45+
impl<A, S, S2> EighInto for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
46+
where
47+
A: Scalar + Lapack,
48+
S: DataMut<Elem = A>,
49+
S2: DataMut<Elem = A>,
50+
{
51+
type EigVal = Array1<A::Real>;
52+
53+
fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
54+
let (val, _) = self.eigh_inplace(uplo)?;
55+
Ok((val, self))
56+
}
57+
}
58+
4559
impl<A, S> Eigh for ArrayBase<S, Ix2>
4660
where
4761
A: Scalar + Lapack,
@@ -56,6 +70,21 @@ where
5670
}
5771
}
5872

73+
impl<A, S, S2> Eigh for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
74+
where
75+
A: Scalar + Lapack,
76+
S: Data<Elem = A>,
77+
S2: Data<Elem = A>,
78+
{
79+
type EigVal = Array1<A::Real>;
80+
type EigVec = (Array2<A>, Array2<A>);
81+
82+
fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
83+
let (a, b) = (self.0.to_owned(), self.1.to_owned());
84+
(a, b).eigh_into(uplo)
85+
}
86+
}
87+
5988
impl<A, S> EighInplace for ArrayBase<S, Ix2>
6089
where
6190
A: Scalar + Lapack,
@@ -75,6 +104,42 @@ where
75104
}
76105
}
77106

107+
impl<A, S, S2> EighInplace for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
108+
where
109+
A: Scalar + Lapack,
110+
S: DataMut<Elem = A>,
111+
S2: DataMut<Elem = A>,
112+
{
113+
type EigVal = Array1<A::Real>;
114+
115+
fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
116+
let layout = self.0.square_layout()?;
117+
// XXX Force layout to be Fortran (see #146)
118+
match layout {
119+
MatrixLayout::C(_) => self.0.swap_axes(0, 1),
120+
MatrixLayout::F(_) => {}
121+
}
122+
123+
let layout = self.1.square_layout()?;
124+
match layout {
125+
MatrixLayout::C(_) => self.1.swap_axes(0, 1),
126+
MatrixLayout::F(_) => {}
127+
}
128+
129+
let s = unsafe {
130+
A::eigh_generalized(
131+
true,
132+
self.0.square_layout()?,
133+
uplo,
134+
self.0.as_allocated_mut()?,
135+
self.1.as_allocated_mut()?,
136+
)?
137+
};
138+
139+
Ok((ArrayBase::from(s), self))
140+
}
141+
}
142+
78143
/// Calculate eigenvalues without eigenvectors
79144
pub trait EigValsh {
80145
type EigVal;

src/lapack/eigh.rs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@ use super::{into_result, UPLO};
1212
/// Wraps `*syev` for real and `*heev` for complex
1313
pub trait Eigh_: Scalar {
1414
unsafe fn eigh(calc_eigenvec: bool, l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Vec<Self::Real>>;
15+
unsafe fn eigh_generalized(
16+
calc_eigenvec: bool,
17+
l: MatrixLayout,
18+
uplo: UPLO,
19+
a: &mut [Self],
20+
b: &mut [Self],
21+
) -> Result<Vec<Self::Real>>;
1522
}
1623

1724
macro_rules! impl_eigh {
18-
($scalar:ty, $ev:path) => {
25+
($scalar:ty, $ev:path, $evg:path) => {
1926
impl Eigh_ for $scalar {
2027
unsafe fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<Vec<Self::Real>> {
2128
let (n, _) = l.size();
@@ -24,11 +31,36 @@ macro_rules! impl_eigh {
2431
let info = $ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w);
2532
into_result(info, w)
2633
}
34+
35+
unsafe fn eigh_generalized(
36+
calc_v: bool,
37+
l: MatrixLayout,
38+
uplo: UPLO,
39+
mut a: &mut [Self],
40+
mut b: &mut [Self],
41+
) -> Result<Vec<Self::Real>> {
42+
let (n, _) = l.size();
43+
let jobz = if calc_v { b'V' } else { b'N' };
44+
let mut w = vec![Self::Real::zero(); n as usize];
45+
let info = $evg(
46+
l.lapacke_layout(),
47+
1,
48+
jobz,
49+
uplo as u8,
50+
n,
51+
&mut a,
52+
n,
53+
&mut b,
54+
n,
55+
&mut w,
56+
);
57+
into_result(info, w)
58+
}
2759
}
2860
};
2961
} // impl_eigh!
3062

31-
impl_eigh!(f64, lapacke::dsyev);
32-
impl_eigh!(f32, lapacke::ssyev);
33-
impl_eigh!(c64, lapacke::zheev);
34-
impl_eigh!(c32, lapacke::cheev);
63+
impl_eigh!(f64, lapacke::dsyev, lapacke::dsygv);
64+
impl_eigh!(f32, lapacke::ssyev, lapacke::ssygv);
65+
impl_eigh!(c64, lapacke::zheev, lapacke::zhegv);
66+
impl_eigh!(c32, lapacke::cheev, lapacke::chegv);

src/lapack/svddc.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ use num_traits::Zero;
33

44
use crate::error::*;
55
use crate::layout::MatrixLayout;
6-
use crate::types::*;
76
use crate::svddc::UVTFlag;
7+
use crate::types::*;
88

9-
use super::{SVDOutput, into_result};
9+
use super::{into_result, SVDOutput};
1010

1111
pub trait SVDDC_: Scalar {
1212
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
@@ -15,11 +15,7 @@ pub trait SVDDC_: Scalar {
1515
macro_rules! impl_svdd {
1616
($scalar:ty, $gesdd:path) => {
1717
impl SVDDC_ for $scalar {
18-
unsafe fn svddc(
19-
l: MatrixLayout,
20-
jobz: UVTFlag,
21-
mut a: &mut [Self],
22-
) -> Result<SVDOutput<Self>> {
18+
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self]) -> Result<SVDOutput<Self>> {
2319
let (m, n) = l.size();
2420
let k = m.min(n);
2521
let lda = l.lda();
@@ -51,11 +47,7 @@ macro_rules! impl_svdd {
5147
SVDOutput {
5248
s: s,
5349
u: if jobz == UVTFlag::None { None } else { Some(u) },
54-
vt: if jobz == UVTFlag::None {
55-
None
56-
} else {
57-
Some(vt)
58-
},
50+
vt: if jobz == UVTFlag::None { None } else { Some(vt) },
5951
},
6052
)
6153
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
//! - [Random matrix generators](generate/index.html)
3838
//! - [Scalar trait](types/trait.Scalar.html)
3939
40+
#[macro_use]
41+
extern crate ndarray;
42+
4043
extern crate blas_src;
4144
extern crate lapack_src;
4245

@@ -52,6 +55,7 @@ pub mod inner;
5255
pub mod krylov;
5356
pub mod lapack;
5457
pub mod layout;
58+
pub mod lobpcg;
5559
pub mod norm;
5660
pub mod operator;
5761
pub mod opnorm;
@@ -73,6 +77,7 @@ pub use eigh::*;
7377
pub use generate::*;
7478
pub use inner::*;
7579
pub use layout::*;
80+
pub use lobpcg::{TruncatedEig, TruncatedOrder, TruncatedSvd};
7681
pub use norm::*;
7782
pub use operator::*;
7883
pub use opnorm::*;

0 commit comments

Comments
 (0)