Skip to content

Commit 0ed6fd2

Browse files
committed
New API, more clear and more efficient.
1 parent 0c4f7e7 commit 0ed6fd2

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

src/linalg/impl_linalg.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -817,17 +817,16 @@ mod blas_tests {
817817
}
818818

819819
#[allow(dead_code)]
820-
fn general_outer_to_dyn<Sa, Sb, I, F, T>(
820+
fn general_outer_to_dyn<Sa, Sb, F, T>(
821821
a: &ArrayBase<Sa, IxDyn>,
822-
b: &ArrayBase<Sb, I>,
822+
b: &ArrayBase<Sb, IxDyn>,
823823
f: F,
824824
) -> ArrayD<T>
825825
where
826826
T: Copy,
827827
Sa: Data<Elem = T>,
828828
Sb: Data<Elem = T>,
829-
I: Dimension,
830-
F: Fn(ArrayViewMut<T, IxDyn>, T, &ArrayBase<Sb, I>) -> (),
829+
F: Fn(T, T) -> T,
831830
{
832831
//Iterators on the shapes, compelted by 1s
833832
let a_shape_iter = a.shape().iter().chain([1].iter().cycle());
@@ -843,25 +842,24 @@ where
843842
unsafe {
844843
let mut res: ArrayD<T> = ArrayBase::uninitialized(res_dim);
845844
let res_chunks = res.exact_chunks_mut(b.shape());
846-
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| f(res_chunk, a_elem, b));
845+
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
846+
Zip::from(res_chunk)
847+
.and(b)
848+
.apply(|res_elem, &b_elem| *res_elem = f(a_elem, b_elem))
849+
});
847850
res
848851
}
849852
}
850853

851854
#[allow(dead_code, clippy::type_repetition_in_bounds)]
852-
fn kron_to_dyn<Sa, I, Sb, T>(a: &ArrayBase<Sa, IxDyn>, b: &ArrayBase<Sb, I>) -> Array<T, IxDyn>
855+
fn kron_to_dyn<Sa, Sb, T>(a: &ArrayBase<Sa, IxDyn>, b: &ArrayBase<Sb, IxDyn>) -> Array<T, IxDyn>
853856
where
854857
T: Copy,
855858
Sa: Data<Elem = T>,
856859
Sb: Data<Elem = T>,
857-
I: Dimension,
858-
T: crate::ScalarOperand + std::ops::MulAssign,
859-
for<'a> &'a ArrayBase<Sb, I>: std::ops::Mul<T, Output = Array<T, I>>,
860+
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
860861
{
861-
general_outer_to_dyn(a, b, |mut res, x, a| {
862-
res.assign(a);
863-
res *= x
864-
})
862+
general_outer_to_dyn(a, b, std::ops::Mul::mul)
865863
}
866864

867865
#[allow(dead_code)]
@@ -875,7 +873,7 @@ where
875873
Sa: Data<Elem = T>,
876874
Sb: Data<Elem = T>,
877875
I: Dimension,
878-
F: Fn(ArrayViewMut<T, I>, T, &ArrayBase<Sb, I>) -> (),
876+
F: Fn(T, T) -> T,
879877
{
880878
let mut res_dim = a.raw_dim();
881879
let mut res_dim_view = res_dim.as_array_view_mut();
@@ -884,7 +882,11 @@ where
884882
unsafe {
885883
let mut res: Array<T, I> = ArrayBase::uninitialized(res_dim);
886884
let res_chunks = res.exact_chunks_mut(b.raw_dim());
887-
Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b));
885+
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
886+
Zip::from(res_chunk)
887+
.and(b)
888+
.apply(|r_elem, &b_elem| *r_elem = f(a_elem, b_elem))
889+
});
888890
res
889891
}
890892
}
@@ -896,13 +898,9 @@ where
896898
Sa: Data<Elem = T>,
897899
Sb: Data<Elem = T>,
898900
I: Dimension,
899-
T: crate::ScalarOperand + std::ops::MulAssign,
900-
for<'a> &'a ArrayBase<Sb, I>: std::ops::Mul<T, Output = Array<T, I>>,
901+
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
901902
{
902-
general_outer_same_size(a, b, |mut res, x, a| {
903-
res.assign(&a);
904-
res *= x
905-
})
903+
general_outer_same_size(a, b, std::ops::Mul::mul)
906904
}
907905

908906
#[cfg(test)]
@@ -922,7 +920,7 @@ mod kron_test {
922920
[[110, 0, 7], [523, 21, -12]]
923921
];
924922
let res1 = kron_same_size(&a, &b);
925-
let res2 = kron_to_dyn(&a.clone().into_dyn(), &b);
923+
let res2 = kron_to_dyn(&a.clone().into_dyn(), &b.clone().into_dyn());
926924
assert_eq!(res1.clone().into_dyn(), res2);
927925
for a0 in 0..a.len_of(Axis(0)) {
928926
for a1 in 0..a.len_of(Axis(1)) {

0 commit comments

Comments
 (0)