Skip to content

Commit 3bf9387

Browse files
authored
Merge pull request #290 from jturner314/check-compatible-shapes
Add checks for matching shapes in Solve, SolveH, and EighInplace
2 parents db13838 + d7e5672 commit 3bf9387

File tree

6 files changed

+172
-0
lines changed

6 files changed

+172
-0
lines changed

ndarray-linalg/src/eigh.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,17 @@ where
144144
{
145145
type EigVal = Array1<A::Real>;
146146

147+
/// Solves the generalized eigenvalue problem.
148+
///
149+
/// # Panics
150+
///
151+
/// Panics if the shapes of the matrices are different.
147152
fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
153+
assert_eq!(
154+
self.0.shape(),
155+
self.1.shape(),
156+
"The shapes of the matrices must be identical.",
157+
);
148158
let layout = self.0.square_layout()?;
149159
// XXX Force layout to be Fortran (see #146)
150160
match layout {

ndarray-linalg/src/solve.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,59 +77,103 @@ pub use lax::{Pivot, Transpose};
7777
pub trait Solve<A: Scalar> {
7878
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
7979
/// is the argument, and `x` is the successful result.
80+
///
81+
/// # Panics
82+
///
83+
/// Panics if the length of `b` is not the equal to the number of columns
84+
/// of `A`.
8085
fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
8186
let mut b = replicate(b);
8287
self.solve_inplace(&mut b)?;
8388
Ok(b)
8489
}
90+
8591
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
8692
/// is the argument, and `x` is the successful result.
93+
///
94+
/// # Panics
95+
///
96+
/// Panics if the length of `b` is not the equal to the number of columns
97+
/// of `A`.
8798
fn solve_into<S: DataMut<Elem = A>>(
8899
&self,
89100
mut b: ArrayBase<S, Ix1>,
90101
) -> Result<ArrayBase<S, Ix1>> {
91102
self.solve_inplace(&mut b)?;
92103
Ok(b)
93104
}
105+
94106
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
95107
/// is the argument, and `x` is the successful result.
108+
///
109+
/// # Panics
110+
///
111+
/// Panics if the length of `b` is not the equal to the number of columns
112+
/// of `A`.
96113
fn solve_inplace<'a, S: DataMut<Elem = A>>(
97114
&self,
98115
b: &'a mut ArrayBase<S, Ix1>,
99116
) -> Result<&'a mut ArrayBase<S, Ix1>>;
100117

101118
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
102119
/// is the argument, and `x` is the successful result.
120+
///
121+
/// # Panics
122+
///
123+
/// Panics if the length of `b` is not the equal to the number of rows of
124+
/// `A`.
103125
fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
104126
let mut b = replicate(b);
105127
self.solve_t_inplace(&mut b)?;
106128
Ok(b)
107129
}
130+
108131
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
109132
/// is the argument, and `x` is the successful result.
133+
///
134+
/// # Panics
135+
///
136+
/// Panics if the length of `b` is not the equal to the number of rows of
137+
/// `A`.
110138
fn solve_t_into<S: DataMut<Elem = A>>(
111139
&self,
112140
mut b: ArrayBase<S, Ix1>,
113141
) -> Result<ArrayBase<S, Ix1>> {
114142
self.solve_t_inplace(&mut b)?;
115143
Ok(b)
116144
}
145+
117146
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
118147
/// is the argument, and `x` is the successful result.
148+
///
149+
/// # Panics
150+
///
151+
/// Panics if the length of `b` is not the equal to the number of rows of
152+
/// `A`.
119153
fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
120154
&self,
121155
b: &'a mut ArrayBase<S, Ix1>,
122156
) -> Result<&'a mut ArrayBase<S, Ix1>>;
123157

124158
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
125159
/// is the argument, and `x` is the successful result.
160+
///
161+
/// # Panics
162+
///
163+
/// Panics if the length of `b` is not the equal to the number of rows of
164+
/// `A`.
126165
fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
127166
let mut b = replicate(b);
128167
self.solve_h_inplace(&mut b)?;
129168
Ok(b)
130169
}
131170
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
132171
/// is the argument, and `x` is the successful result.
172+
///
173+
/// # Panics
174+
///
175+
/// Panics if the length of `b` is not the equal to the number of rows of
176+
/// `A`.
133177
fn solve_h_into<S: DataMut<Elem = A>>(
134178
&self,
135179
mut b: ArrayBase<S, Ix1>,
@@ -139,6 +183,11 @@ pub trait Solve<A: Scalar> {
139183
}
140184
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
141185
/// is the argument, and `x` is the successful result.
186+
///
187+
/// # Panics
188+
///
189+
/// Panics if the length of `b` is not the equal to the number of rows of
190+
/// `A`.
142191
fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
143192
&self,
144193
b: &'a mut ArrayBase<S, Ix1>,
@@ -167,6 +216,11 @@ where
167216
where
168217
Sb: DataMut<Elem = A>,
169218
{
219+
assert_eq!(
220+
rhs.len(),
221+
self.a.len_of(Axis(1)),
222+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
223+
);
170224
A::solve(
171225
self.a.square_layout()?,
172226
Transpose::No,
@@ -183,6 +237,11 @@ where
183237
where
184238
Sb: DataMut<Elem = A>,
185239
{
240+
assert_eq!(
241+
rhs.len(),
242+
self.a.len_of(Axis(0)),
243+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
244+
);
186245
A::solve(
187246
self.a.square_layout()?,
188247
Transpose::Transpose,
@@ -199,6 +258,11 @@ where
199258
where
200259
Sb: DataMut<Elem = A>,
201260
{
261+
assert_eq!(
262+
rhs.len(),
263+
self.a.len_of(Axis(0)),
264+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
265+
);
202266
A::solve(
203267
self.a.square_layout()?,
204268
Transpose::Hermite,

ndarray-linalg/src/solveh.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,42 @@ pub trait SolveH<A: Scalar> {
6969
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
7070
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
7171
/// `x` is the successful result.
72+
///
73+
/// # Panics
74+
///
75+
/// Panics if the length of `b` is not the equal to the number of columns
76+
/// of `A`.
7277
fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
7378
let mut b = replicate(b);
7479
self.solveh_inplace(&mut b)?;
7580
Ok(b)
7681
}
82+
7783
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
7884
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
7985
/// `x` is the successful result.
86+
///
87+
/// # Panics
88+
///
89+
/// Panics if the length of `b` is not the equal to the number of columns
90+
/// of `A`.
8091
fn solveh_into<S: DataMut<Elem = A>>(
8192
&self,
8293
mut b: ArrayBase<S, Ix1>,
8394
) -> Result<ArrayBase<S, Ix1>> {
8495
self.solveh_inplace(&mut b)?;
8596
Ok(b)
8697
}
98+
8799
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
88100
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
89101
/// `x` is the successful result. The value of `x` is also assigned to the
90102
/// argument.
103+
///
104+
/// # Panics
105+
///
106+
/// Panics if the length of `b` is not the equal to the number of columns
107+
/// of `A`.
91108
fn solveh_inplace<'a, S: DataMut<Elem = A>>(
92109
&self,
93110
b: &'a mut ArrayBase<S, Ix1>,
@@ -113,6 +130,11 @@ where
113130
where
114131
Sb: DataMut<Elem = A>,
115132
{
133+
assert_eq!(
134+
rhs.len(),
135+
self.a.len_of(Axis(1)),
136+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
137+
);
116138
A::solveh(
117139
self.a.square_layout()?,
118140
UPLO::Upper,

ndarray-linalg/tests/eigh.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
use ndarray::*;
22
use ndarray_linalg::*;
33

4+
#[should_panic]
5+
#[test]
6+
fn eigh_generalized_shape_mismatch() {
7+
let a = Array2::<f64>::eye(3);
8+
let b = Array2::<f64>::eye(2);
9+
let _ = (a, b).eigh_inplace(UPLO::Upper);
10+
}
11+
412
#[test]
513
fn fixed() {
614
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);

ndarray-linalg/tests/solve.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ fn solve_random_complex() {
125125
}
126126
}
127127

128+
#[should_panic]
129+
#[test]
130+
fn solve_shape_mismatch() {
131+
let a: Array2<f64> = random((3, 3));
132+
let b: Array1<f64> = random(2);
133+
let _ = a.solve_into(b);
134+
}
135+
128136
#[test]
129137
fn solve_t_random_float() {
130138
for n in 0..=8 {
@@ -140,6 +148,14 @@ fn solve_t_random_float() {
140148
}
141149
}
142150

151+
#[should_panic]
152+
#[test]
153+
fn solve_t_shape_mismatch() {
154+
let a: Array2<f64> = random((3, 3).f());
155+
let b: Array1<f64> = random(4);
156+
let _ = a.solve_into(b);
157+
}
158+
143159
#[test]
144160
fn solve_t_random_complex() {
145161
for n in 0..=8 {
@@ -155,6 +171,15 @@ fn solve_t_random_complex() {
155171
}
156172
}
157173

174+
#[should_panic]
175+
#[test]
176+
fn solve_factorized_shape_mismatch() {
177+
let a: Array2<f64> = random((3, 3));
178+
let b: Array1<f64> = random(4);
179+
let f = a.factorize_into().unwrap();
180+
let _ = f.solve_into(b);
181+
}
182+
158183
#[test]
159184
fn solve_h_random_float() {
160185
for n in 0..=8 {
@@ -170,6 +195,15 @@ fn solve_h_random_float() {
170195
}
171196
}
172197

198+
#[should_panic]
199+
#[test]
200+
fn solve_factorized_t_shape_mismatch() {
201+
let a: Array2<f64> = random((3, 3).f());
202+
let b: Array1<f64> = random(4);
203+
let f = a.factorize_into().unwrap();
204+
let _ = f.solve_into(b);
205+
}
206+
173207
#[test]
174208
fn solve_h_random_complex() {
175209
for n in 0..=8 {

ndarray-linalg/tests/solveh.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
use ndarray::*;
22
use ndarray_linalg::*;
33

4+
#[should_panic]
5+
#[test]
6+
fn solveh_shape_mismatch() {
7+
let a: Array2<f64> = random_hpd(3);
8+
let b: Array1<f64> = random(2);
9+
let _ = a.solveh_into(b);
10+
}
11+
12+
#[should_panic]
13+
#[test]
14+
fn factorizeh_solveh_shape_mismatch() {
15+
let a: Array2<f64> = random_hpd(3);
16+
let b: Array1<f64> = random(2);
17+
let f = a.factorizeh_into().unwrap();
18+
let _ = f.solveh_into(b);
19+
}
20+
421
#[test]
522
fn solveh_random() {
623
let a: Array2<f64> = random_hpd(3);
@@ -15,6 +32,23 @@ fn solveh_random() {
1532
assert_close_l2!(&x, &y, 1e-7);
1633
}
1734

35+
#[should_panic]
36+
#[test]
37+
fn solveh_t_shape_mismatch() {
38+
let a: Array2<f64> = random_hpd(3).reversed_axes();
39+
let b: Array1<f64> = random(2);
40+
let _ = a.solveh_into(b);
41+
}
42+
43+
#[should_panic]
44+
#[test]
45+
fn factorizeh_solveh_t_shape_mismatch() {
46+
let a: Array2<f64> = random_hpd(3).reversed_axes();
47+
let b: Array1<f64> = random(2);
48+
let f = a.factorizeh_into().unwrap();
49+
let _ = f.solveh_into(b);
50+
}
51+
1852
#[test]
1953
fn solveh_random_t() {
2054
let a: Array2<f64> = random_hpd(3).reversed_axes();

0 commit comments

Comments
 (0)