Skip to content

Commit 6bbe87d

Browse files
committed
Add tests for multi-column b
1 parent 05a2bcd commit 6bbe87d

File tree

1 file changed

+189
-0
lines changed

1 file changed

+189
-0
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/// Solve least square problem `|b - Ax|` with multi-column `b`
2+
use approx::AbsDiffEq;
3+
use ndarray::*;
4+
use ndarray_linalg::*;
5+
6+
/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
7+
fn test_exact<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
8+
assert_eq!(a.layout().unwrap().size(), (3, 3));
9+
assert_eq!(b.layout().unwrap().size(), (3, 2));
10+
11+
let result = a.least_squares(&b).unwrap();
12+
// unpack result
13+
let x: Array2<T> = result.solution;
14+
let residual_l2_square: Array1<T::Real> = result.residual_sum_of_squares.unwrap();
15+
16+
// must be full-rank
17+
assert_eq!(result.rank, 3);
18+
19+
// |b - Ax| == 0
20+
for &residual in &residual_l2_square {
21+
assert!(residual < T::real(1.0e-4));
22+
}
23+
24+
// b == Ax
25+
let ax = a.dot(&x);
26+
assert_close_l2!(&b, &ax, T::real(1.0e-4));
27+
}
28+
29+
macro_rules! impl_exact {
30+
($scalar:ty) => {
31+
paste::item! {
32+
#[test]
33+
fn [<least_squares_ $scalar _exact_ac_bc>]() {
34+
let a: Array2<f64> = random((3, 3));
35+
let b: Array2<f64> = random((3, 2));
36+
test_exact(a, b)
37+
}
38+
39+
#[test]
40+
fn [<least_squares_ $scalar _exact_ac_bf>]() {
41+
let a: Array2<f64> = random((3, 3));
42+
let b: Array2<f64> = random((3, 2).f());
43+
test_exact(a, b)
44+
}
45+
46+
#[test]
47+
fn [<least_squares_ $scalar _exact_af_bc>]() {
48+
let a: Array2<f64> = random((3, 3).f());
49+
let b: Array2<f64> = random((3, 2));
50+
test_exact(a, b)
51+
}
52+
53+
#[test]
54+
fn [<least_squares_ $scalar _exact_af_bf>]() {
55+
let a: Array2<f64> = random((3, 3).f());
56+
let b: Array2<f64> = random((3, 2).f());
57+
test_exact(a, b)
58+
}
59+
}
60+
};
61+
}
62+
63+
impl_exact!(f32);
64+
impl_exact!(f64);
65+
impl_exact!(c32);
66+
impl_exact!(c64);
67+
68+
/// #column < #row case.
69+
/// Linear problem is overdetermined, `|b - Ax| > 0`.
70+
fn test_overdetermined<T: Scalar + Lapack>(a: Array2<T>, bs: Array2<T>)
71+
where
72+
T::Real: AbsDiffEq<Epsilon = T::Real>,
73+
{
74+
assert_eq!(a.layout().unwrap().size(), (4, 3));
75+
assert_eq!(bs.layout().unwrap().size(), (4, 2));
76+
77+
let result = a.least_squares(&bs).unwrap();
78+
// unpack result
79+
let xs = result.solution;
80+
let residual_l2_square = result.residual_sum_of_squares.unwrap();
81+
82+
// Must be full-rank
83+
assert_eq!(result.rank, 3);
84+
85+
for j in 0..2 {
86+
let b = bs.index_axis(Axis(1), j);
87+
let x = xs.index_axis(Axis(1), j);
88+
let residual = &b - &a.dot(&x);
89+
let residual_l2_sq = residual_l2_square[j];
90+
assert!(residual_l2_sq.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4)));
91+
92+
// `|residual| < |b|`
93+
assert!(residual.norm_l2() < b.norm_l2());
94+
}
95+
}
96+
97+
macro_rules! impl_overdetermined {
98+
($scalar:ty) => {
99+
paste::item! {
100+
#[test]
101+
fn [<least_squares_ $scalar _overdetermined_ac_bc>]() {
102+
let a: Array2<f64> = random((4, 3));
103+
let b: Array2<f64> = random((4, 2));
104+
test_overdetermined(a, b)
105+
}
106+
107+
#[test]
108+
fn [<least_squares_ $scalar _overdetermined_af_bc>]() {
109+
let a: Array2<f64> = random((4, 3).f());
110+
let b: Array2<f64> = random((4, 2));
111+
test_overdetermined(a, b)
112+
}
113+
114+
#[test]
115+
fn [<least_squares_ $scalar _overdetermined_ac_bf>]() {
116+
let a: Array2<f64> = random((4, 3));
117+
let b: Array2<f64> = random((4, 2).f());
118+
test_overdetermined(a, b)
119+
}
120+
121+
#[test]
122+
fn [<least_squares_ $scalar _overdetermined_af_bf>]() {
123+
let a: Array2<f64> = random((4, 3).f());
124+
let b: Array2<f64> = random((4, 2).f());
125+
test_overdetermined(a, b)
126+
}
127+
}
128+
};
129+
}
130+
131+
impl_overdetermined!(f32);
132+
impl_overdetermined!(f64);
133+
impl_overdetermined!(c32);
134+
impl_overdetermined!(c64);
135+
136+
/// #column > #row case.
137+
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
138+
fn test_underdetermined<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
139+
assert_eq!(a.layout().unwrap().size(), (3, 4));
140+
assert_eq!(b.layout().unwrap().size(), (3, 2));
141+
142+
let result = a.least_squares(&b).unwrap();
143+
assert_eq!(result.rank, 3);
144+
assert!(result.residual_sum_of_squares.is_none());
145+
146+
// b == Ax
147+
let x = result.solution;
148+
let ax = a.dot(&x);
149+
assert_close_l2!(&b, &ax, T::real(1.0e-4));
150+
}
151+
152+
macro_rules! impl_underdetermined {
153+
($scalar:ty) => {
154+
paste::item! {
155+
#[test]
156+
fn [<least_squares_ $scalar _underdetermined_ac_bc>]() {
157+
let a: Array2<f64> = random((3, 4));
158+
let b: Array2<f64> = random((3, 2));
159+
test_underdetermined(a, b)
160+
}
161+
162+
#[test]
163+
fn [<least_squares_ $scalar _underdetermined_af_bc>]() {
164+
let a: Array2<f64> = random((3, 4).f());
165+
let b: Array2<f64> = random((3, 2));
166+
test_underdetermined(a, b)
167+
}
168+
169+
#[test]
170+
fn [<least_squares_ $scalar _underdetermined_ac_bf>]() {
171+
let a: Array2<f64> = random((3, 4));
172+
let b: Array2<f64> = random((3, 2).f());
173+
test_underdetermined(a, b)
174+
}
175+
176+
#[test]
177+
fn [<least_squares_ $scalar _underdetermined_af_bf>]() {
178+
let a: Array2<f64> = random((3, 4).f());
179+
let b: Array2<f64> = random((3, 2).f());
180+
test_underdetermined(a, b)
181+
}
182+
}
183+
};
184+
}
185+
186+
impl_underdetermined!(f32);
187+
impl_underdetermined!(f64);
188+
impl_underdetermined!(c32);
189+
impl_underdetermined!(c64);

0 commit comments

Comments
 (0)