Skip to content

Commit 24bef8f

Browse files
authored
Merge pull request #258 from kngwyu/ndarray-014
Bump ndarray version to 0.14
2 parents d96c02a + 15bc1ad commit 24bef8f

File tree

5 files changed

+32
-54
lines changed

5 files changed

+32
-54
lines changed

ndarray-linalg/Cargo.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ rand = "0.5"
3636
thiserror = "1.0.20"
3737

3838
[dependencies.ndarray]
39-
version = "0.13.0"
39+
version = "0.14"
4040
features = ["blas", "approx"]
4141
default-features = false
4242

@@ -46,9 +46,10 @@ path = "../lax"
4646
default-features = false
4747

4848
[dev-dependencies]
49-
paste = "0.1.9"
50-
criterion = "0.3.1"
51-
approx = { version = "0.3.2", features = ["num-complex"] }
49+
paste = "1.0"
50+
criterion = "0.3"
51+
# Keep the same version as ndarray's dependency!
52+
approx = { version = "0.4", features = ["num-complex"] }
5253

5354
[[bench]]
5455
name = "truncated_eig"

ndarray-linalg/src/generate.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,8 @@ where
110110
A: Scalar,
111111
S: Data<Elem = A>,
112112
{
113-
let views: Vec<_> = xs
114-
.iter()
115-
.map(|x| {
116-
let n = x.len();
117-
x.view().into_shape((n, 1)).unwrap()
118-
})
119-
.collect();
120-
stack(Axis(1), &views).map_err(|e| e.into())
113+
let views: Vec<_> = xs.iter().map(|x| x.view()).collect();
114+
stack(Axis(1), &views).map_err(Into::into)
121115
}
122116

123117
/// stack vectors into matrix vertically
@@ -126,12 +120,6 @@ where
126120
A: Scalar,
127121
S: Data<Elem = A>,
128122
{
129-
let views: Vec<_> = xs
130-
.iter()
131-
.map(|x| {
132-
let n = x.len();
133-
x.view().into_shape((1, n)).unwrap()
134-
})
135-
.collect();
136-
stack(Axis(0), &views).map_err(|e| e.into())
123+
let views: Vec<_> = xs.iter().map(|x| x.view()).collect();
124+
stack(Axis(0), &views).map_err(Into::into)
137125
}

ndarray-linalg/src/lobpcg/eig.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,11 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
138138

139139
// add the new eigenvector to the internal constrain matrix
140140
let new_constraints = if let Some(ref constraints) = self.eig.constraints {
141-
let eigvecs_arr = constraints
141+
let eigvecs_arr: Vec<_> = constraints
142142
.gencolumns()
143143
.into_iter()
144144
.chain(vecs.gencolumns().into_iter())
145-
.map(|x| x.insert_axis(Axis(1)))
146-
.collect::<Vec<_>>();
145+
.collect();
147146

148147
stack(Axis(1), &eigvecs_arr).unwrap()
149148
} else {

ndarray-linalg/src/lobpcg/lobpcg.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -354,17 +354,17 @@ pub fn lobpcg<
354354
};
355355

356356
sorted_eig(
357-
stack![
357+
concatenate![
358358
Axis(0),
359-
stack![Axis(1), xax, xar, xap],
360-
stack![Axis(1), xar.t(), rar, rap],
361-
stack![Axis(1), xap.t(), rap.t(), pap]
359+
concatenate![Axis(1), xax, xar, xap],
360+
concatenate![Axis(1), xar.t(), rar, rap],
361+
concatenate![Axis(1), xap.t(), rap.t(), pap]
362362
],
363-
Some(stack![
363+
Some(concatenate![
364364
Axis(0),
365-
stack![Axis(1), xx, xr, xp],
366-
stack![Axis(1), xr.t(), rr, rp],
367-
stack![Axis(1), xp.t(), rp.t(), pp]
365+
concatenate![Axis(1), xx, xr, xp],
366+
concatenate![Axis(1), xr.t(), rr, rp],
367+
concatenate![Axis(1), xp.t(), rp.t(), pp]
368368
]),
369369
size_x,
370370
&order,
@@ -374,15 +374,15 @@ pub fn lobpcg<
374374
p_ap = None;
375375

376376
sorted_eig(
377-
stack![
377+
concatenate![
378378
Axis(0),
379-
stack![Axis(1), xax, xar],
380-
stack![Axis(1), xar.t(), rar]
379+
concatenate![Axis(1), xax, xar],
380+
concatenate![Axis(1), xar.t(), rar]
381381
],
382-
Some(stack![
382+
Some(concatenate![
383383
Axis(0),
384-
stack![Axis(1), xx, xr],
385-
stack![Axis(1), xr.t(), rr]
384+
concatenate![Axis(1), xx, xr],
385+
concatenate![Axis(1), xr.t(), rr]
386386
]),
387387
size_x,
388388
&order,

ndarray-linalg/src/opnorm.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
use lax::Tridiagonal;
44
use ndarray::*;
55

6-
use crate::convert::*;
76
use crate::error::*;
87
use crate::layout::*;
98
use crate::types::*;
@@ -71,10 +70,9 @@ where
7170
NormType::One => {
7271
let zl: Array1<A> = Array::zeros(1);
7372
let zu: Array1<A> = Array::zeros(1);
74-
let dl = stack![Axis(0), self.dl.to_owned(), zl];
75-
let du = stack![Axis(0), zu, self.du.to_owned()];
76-
let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)];
77-
arr
73+
let dl = concatenate![Axis(0), &self.dl, zl]; // n
74+
let du = concatenate![Axis(0), zu, &self.du]; // n
75+
stack![Axis(0), du, &self.d, dl] // 3 x n
7876
}
7977
// opnorm_inf() calculates muximum row sum.
8078
// Therefore, This part align the rows and make a (n x 3) matrix like,
@@ -86,26 +84,18 @@ where
8684
NormType::Infinity => {
8785
let zl: Array1<A> = Array::zeros(1);
8886
let zu: Array1<A> = Array::zeros(1);
89-
let dl = stack![Axis(0), zl, self.dl.to_owned()];
90-
let du = stack![Axis(0), self.du.to_owned(), zu];
91-
let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)];
92-
arr
87+
let dl = concatenate![Axis(0), zl, &self.dl]; // n
88+
let du = concatenate![Axis(0), &self.du, zu]; // n
89+
stack![Axis(1), dl, &self.d, du] // n x 3
9390
}
9491
// opnorm_fro() calculates square root of sum of squares.
9592
// Because it is independent of the shape of matrix,
9693
// this part make a (1 x (3n-2)) matrix like,
9794
// [l1, ..., l{n-1}, d0, ..., d{n-1}, u1, ..., u{n-1}]
9895
NormType::Frobenius => {
99-
let arr = stack![
100-
Axis(1),
101-
into_row(arr1(&self.dl)),
102-
into_row(arr1(&self.d)),
103-
into_row(arr1(&self.du))
104-
];
105-
arr
96+
concatenate![Axis(0), &self.dl, &self.d, &self.du].insert_axis(Axis(0))
10697
}
10798
};
108-
10999
let l = arr.layout()?;
110100
let a = arr.as_allocated()?;
111101
Ok(A::opnorm(t, l, a))

0 commit comments

Comments
 (0)