Skip to content

Commit 64e37bf

Browse files
authored
Merge pull request #224 from rust-ndarray/lapack-qr
Rewrite QR using LAPACK
2 parents bdb2ffb + cbe6df7 commit 64e37bf

File tree

2 files changed

+131
-22
lines changed

2 files changed

+131
-22
lines changed

lax/src/qr.rs

Lines changed: 129 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,120 @@
22
33
use crate::{error::*, layout::MatrixLayout};
44
use cauchy::*;
5-
use num_traits::Zero;
6-
use std::cmp::min;
5+
use num_traits::{ToPrimitive, Zero};
76

8-
/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers)
97
pub trait QR_: Sized {
10-
unsafe fn householder(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
11-
unsafe fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>;
12-
unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
8+
/// Execute Householder reflection as the first step of QR-decomposition
9+
///
10+
/// For C-continuous array,
11+
/// this will call LQ-decomposition of the transposed matrix $ A^T = LQ^T $
12+
fn householder(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
13+
14+
/// Reconstruct Q-matrix from Householder-reflectors
15+
fn q(l: MatrixLayout, a: &mut [Self], tau: &[Self]) -> Result<()>;
16+
17+
/// Execute QR-decomposition at once
18+
fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
1319
}
1420

1521
macro_rules! impl_qr {
16-
($scalar:ty, $qrf:path, $gqr:path) => {
22+
($scalar:ty, $qrf:path, $lqf:path, $gqr:path, $glq:path) => {
1723
impl QR_ for $scalar {
18-
unsafe fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result<Vec<Self>> {
19-
let (row, col) = l.size();
20-
let k = min(row, col);
24+
fn householder(l: MatrixLayout, mut a: &mut [Self]) -> Result<Vec<Self>> {
25+
let m = l.lda();
26+
let n = l.len();
27+
let k = m.min(n);
2128
let mut tau = vec![Self::zero(); k as usize];
22-
$qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau).as_lapack_result()?;
29+
30+
// eval work size
31+
let mut info = 0;
32+
let mut work_size = [Self::zero()];
33+
unsafe {
34+
match l {
35+
MatrixLayout::F { .. } => {
36+
$qrf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info);
37+
}
38+
MatrixLayout::C { .. } => {
39+
$lqf(m, n, &mut a, m, &mut tau, &mut work_size, -1, &mut info);
40+
}
41+
}
42+
}
43+
info.as_lapack_result()?;
44+
45+
// calc
46+
let lwork = work_size[0].to_usize().unwrap();
47+
let mut work = vec![Self::zero(); lwork];
48+
unsafe {
49+
match l {
50+
MatrixLayout::F { .. } => {
51+
$qrf(
52+
m,
53+
n,
54+
&mut a,
55+
m,
56+
&mut tau,
57+
&mut work,
58+
lwork as i32,
59+
&mut info,
60+
);
61+
}
62+
MatrixLayout::C { .. } => {
63+
$lqf(
64+
m,
65+
n,
66+
&mut a,
67+
m,
68+
&mut tau,
69+
&mut work,
70+
lwork as i32,
71+
&mut info,
72+
);
73+
}
74+
}
75+
}
76+
info.as_lapack_result()?;
77+
2378
Ok(tau)
2479
}
2580

26-
unsafe fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
27-
let (row, col) = l.size();
28-
let k = min(row, col);
29-
$gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau).as_lapack_result()?;
81+
fn q(l: MatrixLayout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
82+
let m = l.lda();
83+
let n = l.len();
84+
let k = m.min(n);
85+
assert_eq!(tau.len(), k as usize);
86+
87+
// eval work size
88+
let mut info = 0;
89+
let mut work_size = [Self::zero()];
90+
unsafe {
91+
match l {
92+
MatrixLayout::F { .. } => {
93+
$gqr(m, k, k, &mut a, m, &tau, &mut work_size, -1, &mut info)
94+
}
95+
MatrixLayout::C { .. } => {
96+
$glq(k, n, k, &mut a, m, &tau, &mut work_size, -1, &mut info)
97+
}
98+
}
99+
};
100+
101+
// calc
102+
let lwork = work_size[0].to_usize().unwrap();
103+
let mut work = vec![Self::zero(); lwork];
104+
unsafe {
105+
match l {
106+
MatrixLayout::F { .. } => {
107+
$gqr(m, k, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info)
108+
}
109+
MatrixLayout::C { .. } => {
110+
$glq(k, n, k, &mut a, m, &tau, &mut work, lwork as i32, &mut info)
111+
}
112+
}
113+
}
114+
info.as_lapack_result()?;
30115
Ok(())
31116
}
32117

33-
unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>> {
118+
fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>> {
34119
let tau = Self::householder(l, a)?;
35120
let r = Vec::from(&*a);
36121
Self::q(l, a, &tau)?;
@@ -40,7 +125,31 @@ macro_rules! impl_qr {
40125
};
41126
} // endmacro
42127

43-
impl_qr!(f64, lapacke::dgeqrf, lapacke::dorgqr);
44-
impl_qr!(f32, lapacke::sgeqrf, lapacke::sorgqr);
45-
impl_qr!(c64, lapacke::zgeqrf, lapacke::zungqr);
46-
impl_qr!(c32, lapacke::cgeqrf, lapacke::cungqr);
128+
impl_qr!(
129+
f64,
130+
lapack::dgeqrf,
131+
lapack::dgelqf,
132+
lapack::dorgqr,
133+
lapack::dorglq
134+
);
135+
impl_qr!(
136+
f32,
137+
lapack::sgeqrf,
138+
lapack::sgelqf,
139+
lapack::sorgqr,
140+
lapack::sorglq
141+
);
142+
impl_qr!(
143+
c64,
144+
lapack::zgeqrf,
145+
lapack::zgelqf,
146+
lapack::zungqr,
147+
lapack::zunglq
148+
);
149+
impl_qr!(
150+
c32,
151+
lapack::cgeqrf,
152+
lapack::cgelqf,
153+
lapack::cungqr,
154+
lapack::cunglq
155+
);

ndarray-linalg/src/qr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ where
6161

6262
fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)> {
6363
let l = self.square_layout()?;
64-
let r = unsafe { A::qr(l, self.as_allocated_mut()?)? };
64+
let r = A::qr(l, self.as_allocated_mut()?)?;
6565
let r: Array2<_> = into_matrix(l, r)?;
6666
Ok((self, r.into_triangular(UPLO::Upper)))
6767
}
@@ -107,7 +107,7 @@ where
107107
let m = self.ncols();
108108
let k = ::std::cmp::min(n, m);
109109
let l = self.layout()?;
110-
let r = unsafe { A::qr(l, self.as_allocated_mut()?)? };
110+
let r = A::qr(l, self.as_allocated_mut()?)?;
111111
let r: Array2<_> = into_matrix(l, r)?;
112112
let q = self;
113113
Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m)))

0 commit comments

Comments
 (0)