Skip to content

Commit a4b3118

Browse files
authored
Merge pull request #331 from rust-ndarray/revise-flag-naming
Revise enum namings
2 parents 08aae3b + 86c61c3 commit a4b3118

File tree

9 files changed

+186
-191
lines changed

9 files changed

+186
-191
lines changed

lax/src/eig.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ macro_rules! impl_eig_complex {
3535
// eigenvalues are the eigenvalues computed with `A`.
3636
let (jobvl, jobvr) = if calc_v {
3737
match l {
38-
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
39-
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
38+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
39+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
4040
}
4141
} else {
42-
(EigenVectorFlag::Not, EigenVectorFlag::Not)
42+
(JobEv::None, JobEv::None)
4343
};
4444
let mut eigs: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
4545
let mut rwork: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(2 * n as usize) };
@@ -143,11 +143,11 @@ macro_rules! impl_eig_real {
143143
// `sgeev`/`dgeev`.
144144
let (jobvl, jobvr) = if calc_v {
145145
match l {
146-
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
147-
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
146+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
147+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
148148
}
149149
} else {
150-
(EigenVectorFlag::Not, EigenVectorFlag::Not)
150+
(JobEv::None, JobEv::None)
151151
};
152152
let mut eig_re: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
153153
let mut eig_im: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };

lax/src/eigh.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ macro_rules! impl_eigh {
4141
) -> Result<Vec<Self::Real>> {
4242
assert_eq!(layout.len(), layout.lda());
4343
let n = layout.len();
44-
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
44+
let jobz = if calc_v { JobEv::All } else { JobEv::None };
4545
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };
4646

4747
$(
@@ -100,7 +100,7 @@ macro_rules! impl_eigh {
100100
) -> Result<Vec<Self::Real>> {
101101
assert_eq!(layout.len(), layout.lda());
102102
let n = layout.len();
103-
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
103+
let jobz = if calc_v { JobEv::All } else { JobEv::None };
104104
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };
105105

106106
$(

lax/src/flags.rs

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//! Charactor flags, e.g. `'T'`, used in LAPACK API
2+
3+
/// Upper/Lower specification for seveal usages
4+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
5+
#[repr(u8)]
6+
pub enum UPLO {
7+
Upper = b'U',
8+
Lower = b'L',
9+
}
10+
11+
impl UPLO {
12+
pub fn t(self) -> Self {
13+
match self {
14+
UPLO::Upper => UPLO::Lower,
15+
UPLO::Lower => UPLO::Upper,
16+
}
17+
}
18+
19+
/// To use Fortran LAPACK API in lapack-sys crate
20+
pub fn as_ptr(&self) -> *const i8 {
21+
self as *const UPLO as *const i8
22+
}
23+
}
24+
25+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
26+
#[repr(u8)]
27+
pub enum Transpose {
28+
No = b'N',
29+
Transpose = b'T',
30+
Hermite = b'C',
31+
}
32+
33+
impl Transpose {
34+
/// To use Fortran LAPACK API in lapack-sys crate
35+
pub fn as_ptr(&self) -> *const i8 {
36+
self as *const Transpose as *const i8
37+
}
38+
}
39+
40+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
41+
#[repr(u8)]
42+
pub enum NormType {
43+
One = b'O',
44+
Infinity = b'I',
45+
Frobenius = b'F',
46+
}
47+
48+
impl NormType {
49+
pub fn transpose(self) -> Self {
50+
match self {
51+
NormType::One => NormType::Infinity,
52+
NormType::Infinity => NormType::One,
53+
NormType::Frobenius => NormType::Frobenius,
54+
}
55+
}
56+
57+
/// To use Fortran LAPACK API in lapack-sys crate
58+
pub fn as_ptr(&self) -> *const i8 {
59+
self as *const NormType as *const i8
60+
}
61+
}
62+
63+
/// Flag for calculating eigenvectors or not
64+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
65+
#[repr(u8)]
66+
pub enum JobEv {
67+
/// Calculate eigenvectors in addition to eigenvalues
68+
All = b'V',
69+
/// Do not calculate eigenvectors. Only calculate eigenvalues.
70+
None = b'N',
71+
}
72+
73+
impl JobEv {
74+
pub fn is_calc(&self) -> bool {
75+
match self {
76+
JobEv::All => true,
77+
JobEv::None => false,
78+
}
79+
}
80+
81+
pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
82+
if self.is_calc() {
83+
Some(f())
84+
} else {
85+
None
86+
}
87+
}
88+
89+
/// To use Fortran LAPACK API in lapack-sys crate
90+
pub fn as_ptr(&self) -> *const i8 {
91+
self as *const JobEv as *const i8
92+
}
93+
}
94+
95+
/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
96+
///
97+
/// For an input array of shape *m*×*n*, the following are computed:
98+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
99+
#[repr(u8)]
100+
pub enum JobSvd {
101+
/// All *m* columns of *U* and all *n* rows of *V*ᵀ.
102+
All = b'A',
103+
/// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ.
104+
Some = b'S',
105+
/// No columns of *U* or rows of *V*ᵀ.
106+
None = b'N',
107+
}
108+
109+
impl JobSvd {
110+
pub fn from_bool(calc_uv: bool) -> Self {
111+
if calc_uv {
112+
JobSvd::All
113+
} else {
114+
JobSvd::None
115+
}
116+
}
117+
118+
pub fn as_ptr(&self) -> *const i8 {
119+
self as *const JobSvd as *const i8
120+
}
121+
}
122+
123+
/// Specify whether input triangular matrix is unit or not
124+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
125+
#[repr(u8)]
126+
pub enum Diag {
127+
/// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1`
128+
Unit = b'U',
129+
/// Non-unit triangular matrix. Its diagonal elements may be different from `1`
130+
NonUnit = b'N',
131+
}
132+
133+
impl Diag {
134+
pub fn as_ptr(&self) -> *const i8 {
135+
self as *const Diag as *const i8
136+
}
137+
}

lax/src/lib.rs

Lines changed: 2 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ extern crate openblas_src as _src;
6969
extern crate netlib_src as _src;
7070

7171
pub mod error;
72+
pub mod flags;
7273
pub mod layout;
7374

7475
mod cholesky;
@@ -88,6 +89,7 @@ mod tridiagonal;
8889
pub use self::cholesky::*;
8990
pub use self::eig::*;
9091
pub use self::eigh::*;
92+
pub use self::flags::*;
9193
pub use self::least_squares::*;
9294
pub use self::opnorm::*;
9395
pub use self::qr::*;
@@ -173,96 +175,6 @@ impl<T> VecAssumeInit for Vec<MaybeUninit<T>> {
173175
}
174176
}
175177

176-
/// Upper/Lower specification for seveal usages
177-
#[derive(Debug, Clone, Copy)]
178-
#[repr(u8)]
179-
pub enum UPLO {
180-
Upper = b'U',
181-
Lower = b'L',
182-
}
183-
184-
impl UPLO {
185-
pub fn t(self) -> Self {
186-
match self {
187-
UPLO::Upper => UPLO::Lower,
188-
UPLO::Lower => UPLO::Upper,
189-
}
190-
}
191-
192-
/// To use Fortran LAPACK API in lapack-sys crate
193-
pub fn as_ptr(&self) -> *const i8 {
194-
self as *const UPLO as *const i8
195-
}
196-
}
197-
198-
#[derive(Debug, Clone, Copy)]
199-
#[repr(u8)]
200-
pub enum Transpose {
201-
No = b'N',
202-
Transpose = b'T',
203-
Hermite = b'C',
204-
}
205-
206-
impl Transpose {
207-
/// To use Fortran LAPACK API in lapack-sys crate
208-
pub fn as_ptr(&self) -> *const i8 {
209-
self as *const Transpose as *const i8
210-
}
211-
}
212-
213-
#[derive(Debug, Clone, Copy)]
214-
#[repr(u8)]
215-
pub enum NormType {
216-
One = b'O',
217-
Infinity = b'I',
218-
Frobenius = b'F',
219-
}
220-
221-
impl NormType {
222-
pub fn transpose(self) -> Self {
223-
match self {
224-
NormType::One => NormType::Infinity,
225-
NormType::Infinity => NormType::One,
226-
NormType::Frobenius => NormType::Frobenius,
227-
}
228-
}
229-
230-
/// To use Fortran LAPACK API in lapack-sys crate
231-
pub fn as_ptr(&self) -> *const i8 {
232-
self as *const NormType as *const i8
233-
}
234-
}
235-
236-
/// Flag for calculating eigenvectors or not
237-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
238-
#[repr(u8)]
239-
pub enum EigenVectorFlag {
240-
Calc = b'V',
241-
Not = b'N',
242-
}
243-
244-
impl EigenVectorFlag {
245-
pub fn is_calc(&self) -> bool {
246-
match self {
247-
EigenVectorFlag::Calc => true,
248-
EigenVectorFlag::Not => false,
249-
}
250-
}
251-
252-
pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
253-
if self.is_calc() {
254-
Some(f())
255-
} else {
256-
None
257-
}
258-
}
259-
260-
/// To use Fortran LAPACK API in lapack-sys crate
261-
pub fn as_ptr(&self) -> *const i8 {
262-
self as *const EigenVectorFlag as *const i8
263-
}
264-
}
265-
266178
/// Create a vector without initialization
267179
///
268180
/// Safety

lax/src/svd.rs

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,9 @@
11
//! Singular-value decomposition
22
3-
use crate::{error::*, layout::MatrixLayout, *};
3+
use super::{error::*, layout::*, *};
44
use cauchy::*;
55
use num_traits::{ToPrimitive, Zero};
66

7-
#[repr(u8)]
8-
#[derive(Debug, Copy, Clone)]
9-
enum FlagSVD {
10-
All = b'A',
11-
// OverWrite = b'O',
12-
// Separately = b'S',
13-
No = b'N',
14-
}
15-
16-
impl FlagSVD {
17-
fn from_bool(calc_uv: bool) -> Self {
18-
if calc_uv {
19-
FlagSVD::All
20-
} else {
21-
FlagSVD::No
22-
}
23-
}
24-
25-
fn as_ptr(&self) -> *const i8 {
26-
self as *const FlagSVD as *const i8
27-
}
28-
}
29-
307
/// Result of SVD
318
pub struct SVDOutput<A: Scalar> {
329
/// diagonal values
@@ -55,24 +32,26 @@ macro_rules! impl_svd {
5532
impl SVD_ for $scalar {
5633
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
5734
let ju = match l {
58-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
59-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
35+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
36+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
6037
};
6138
let jvt = match l {
62-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
63-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
39+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
40+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
6441
};
6542

6643
let m = l.lda();
6744
let mut u = match ju {
68-
FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
69-
FlagSVD::No => None,
45+
JobSvd::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
46+
JobSvd::None => None,
47+
_ => unimplemented!("SVD with partial vector output is not supported yet")
7048
};
7149

7250
let n = l.len();
7351
let mut vt = match jvt {
74-
FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
75-
FlagSVD::No => None,
52+
JobSvd::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
53+
JobSvd::None => None,
54+
_ => unimplemented!("SVD with partial vector output is not supported yet")
7655
};
7756

7857
let k = std::cmp::min(m, n);

0 commit comments

Comments
 (0)