@@ -817,17 +817,16 @@ mod blas_tests {
817
817
}
818
818
819
819
#[ allow( dead_code) ]
820
- fn general_outer_to_dyn < Sa , Sb , I , F , T > (
820
+ fn general_outer_to_dyn < Sa , Sb , F , T > (
821
821
a : & ArrayBase < Sa , IxDyn > ,
822
- b : & ArrayBase < Sb , I > ,
822
+ b : & ArrayBase < Sb , IxDyn > ,
823
823
f : F ,
824
824
) -> ArrayD < T >
825
825
where
826
826
T : Copy ,
827
827
Sa : Data < Elem = T > ,
828
828
Sb : Data < Elem = T > ,
829
- I : Dimension ,
830
- F : Fn ( ArrayViewMut < T , IxDyn > , T , & ArrayBase < Sb , I > ) -> ( ) ,
829
+ F : Fn ( T , T ) -> T ,
831
830
{
832
831
//Iterators on the shapes, compelted by 1s
833
832
let a_shape_iter = a. shape ( ) . iter ( ) . chain ( [ 1 ] . iter ( ) . cycle ( ) ) ;
@@ -843,25 +842,24 @@ where
843
842
unsafe {
844
843
let mut res: ArrayD < T > = ArrayBase :: uninitialized ( res_dim) ;
845
844
let res_chunks = res. exact_chunks_mut ( b. shape ( ) ) ;
846
- Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| f ( res_chunk, a_elem, b) ) ;
845
+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
846
+ Zip :: from ( res_chunk)
847
+ . and ( b)
848
+ . apply ( |res_elem, & b_elem| * res_elem = f ( a_elem, b_elem) )
849
+ } ) ;
847
850
res
848
851
}
849
852
}
850
853
851
854
#[ allow( dead_code, clippy:: type_repetition_in_bounds) ]
852
- fn kron_to_dyn < Sa , I , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , I > ) -> Array < T , IxDyn >
855
+ fn kron_to_dyn < Sa , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , IxDyn > ) -> Array < T , IxDyn >
853
856
where
854
857
T : Copy ,
855
858
Sa : Data < Elem = T > ,
856
859
Sb : Data < Elem = T > ,
857
- I : Dimension ,
858
- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
859
- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
860
+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
860
861
{
861
- general_outer_to_dyn ( a, b, |mut res, x, a| {
862
- res. assign ( a) ;
863
- res *= x
864
- } )
862
+ general_outer_to_dyn ( a, b, std:: ops:: Mul :: mul)
865
863
}
866
864
867
865
#[ allow( dead_code) ]
@@ -875,7 +873,7 @@ where
875
873
Sa : Data < Elem = T > ,
876
874
Sb : Data < Elem = T > ,
877
875
I : Dimension ,
878
- F : Fn ( ArrayViewMut < T , I > , T , & ArrayBase < Sb , I > ) -> ( ) ,
876
+ F : Fn ( T , T ) -> T ,
879
877
{
880
878
let mut res_dim = a. raw_dim ( ) ;
881
879
let mut res_dim_view = res_dim. as_array_view_mut ( ) ;
@@ -884,7 +882,11 @@ where
884
882
unsafe {
885
883
let mut res: Array < T , I > = ArrayBase :: uninitialized ( res_dim) ;
886
884
let res_chunks = res. exact_chunks_mut ( b. raw_dim ( ) ) ;
887
- Zip :: from ( res_chunks) . and ( a) . apply ( |r_c, & x| f ( r_c, x, b) ) ;
885
+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
886
+ Zip :: from ( res_chunk)
887
+ . and ( b)
888
+ . apply ( |r_elem, & b_elem| * r_elem = f ( a_elem, b_elem) )
889
+ } ) ;
888
890
res
889
891
}
890
892
}
@@ -896,13 +898,9 @@ where
896
898
Sa : Data < Elem = T > ,
897
899
Sb : Data < Elem = T > ,
898
900
I : Dimension ,
899
- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
900
- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
901
+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
901
902
{
902
- general_outer_same_size ( a, b, |mut res, x, a| {
903
- res. assign ( & a) ;
904
- res *= x
905
- } )
903
+ general_outer_same_size ( a, b, std:: ops:: Mul :: mul)
906
904
}
907
905
908
906
#[ cfg( test) ]
@@ -922,7 +920,7 @@ mod kron_test {
922
920
[ [ 110 , 0 , 7 ] , [ 523 , 21 , -12 ] ]
923
921
] ;
924
922
let res1 = kron_same_size ( & a, & b) ;
925
- let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b) ;
923
+ let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b. clone ( ) . into_dyn ( ) ) ;
926
924
assert_eq ! ( res1. clone( ) . into_dyn( ) , res2) ;
927
925
for a0 in 0 ..a. len_of ( Axis ( 0 ) ) {
928
926
for a1 in 0 ..a. len_of ( Axis ( 1 ) ) {
0 commit comments