@@ -518,12 +518,17 @@ where
518518} 
519519
520520macro_rules!  all_reduce_func_def { 
521-     ( $doc_str:  expr,  $fn_name:  ident,  $ffi_name:  ident,  $out_type : ty )  => { 
521+     ( $doc_str:  expr,  $fn_name:  ident,  $ffi_name:  ident,  $assoc_type : ident )  => { 
522522        #[ doc=$doc_str] 
523-         pub  fn  $fn_name<T >( input:  & Array <T >)  -> ( $out_type,  $out_type) 
523+         pub  fn  $fn_name<T >( input:  & Array <T >) 
524+             -> ( 
525+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType , 
526+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType 
527+                ) 
524528        where 
525529            T :  HasAfEnum , 
526-             $out_type:  HasAfEnum  + Fromf64 
530+             <T  as  HasAfEnum >:: $assoc_type:  HasAfEnum , 
531+             <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :  HasAfEnum  + Fromf64 , 
527532        { 
528533            let  mut  real:  f64  = 0.0 ; 
529534            let  mut  imag:  f64  = 0.0 ; 
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533538                ) ; 
534539                HANDLE_ERROR ( AfError :: from( err_val) ) ; 
535540            } 
536-             ( <$out_type>:: fromf64( real) ,  <$out_type>:: fromf64( imag) ) 
541+             ( 
542+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :: fromf64( real) , 
543+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :: fromf64( imag) , 
544+             ) 
537545        } 
538546    } ; 
539547} 
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564572    " , 
565573    sum_all, 
566574    af_sum_all, 
567-     T :: AggregateOutType 
575+     AggregateOutType 
568576) ; 
569577
570578all_reduce_func_def ! ( 
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594602    " , 
595603    product_all, 
596604    af_product_all, 
597-     T :: ProductOutType 
605+     ProductOutType 
598606) ; 
599607
600608all_reduce_func_def ! ( 
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623631    " , 
624632    min_all, 
625633    af_min_all, 
626-     T :: InType 
634+     InType 
627635) ; 
628636
629637all_reduce_func_def ! ( 
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652660    " , 
653661    max_all, 
654662    af_max_all, 
655-     T :: InType 
663+     InType 
656664) ; 
657665
658- all_reduce_func_def ! ( 
666+ macro_rules!  all_reduce_func_def2 { 
667+     ( $doc_str:  expr,  $fn_name:  ident,  $ffi_name:  ident,  $out_type: ty)  => { 
668+         #[ doc=$doc_str] 
669+         pub  fn  $fn_name<T >( input:  & Array <T >)  -> ( $out_type,  $out_type) 
670+         where 
671+             T :  HasAfEnum , 
672+             $out_type:  HasAfEnum  + Fromf64 
673+         { 
674+             let  mut  real:  f64  = 0.0 ; 
675+             let  mut  imag:  f64  = 0.0 ; 
676+             unsafe  { 
677+                 let  err_val = $ffi_name( 
678+                     & mut  real as  * mut  c_double,  & mut  imag as  * mut  c_double,  input. get( ) , 
679+                 ) ; 
680+                 HANDLE_ERROR ( AfError :: from( err_val) ) ; 
681+             } 
682+             ( <$out_type>:: fromf64( real) ,  <$out_type>:: fromf64( imag) ) 
683+         } 
684+     } ; 
685+ } 
686+ 
687+ all_reduce_func_def2 ! ( 
659688    " 
660689    Find if all values of Array are non-zero 
661690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682711    bool 
683712) ; 
684713
685- all_reduce_func_def ! ( 
714+ all_reduce_func_def2 ! ( 
686715    " 
687716    Find if any value of Array is non-zero 
688717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709738    bool 
710739) ; 
711740
712- all_reduce_func_def ! ( 
741+ all_reduce_func_def2 ! ( 
713742    " 
714743    Count number of non-zero values in the Array 
715744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751780/// A tuple of summation result. 
752781/// 
753782/// Note: For non-complex data type Arrays, second value of tuple is zero. 
754- pub  fn  sum_nan_all < T > ( input :  & Array < T > ,  val :  f64 )  -> ( T :: AggregateOutType ,  T :: AggregateOutType ) 
783+ pub  fn  sum_nan_all < T > ( 
784+     input :  & Array < T > , 
785+     val :  f64 , 
786+ )  -> ( 
787+     <<T  as  HasAfEnum >:: AggregateOutType  as  HasAfEnum >:: BaseType , 
788+     <<T  as  HasAfEnum >:: AggregateOutType  as  HasAfEnum >:: BaseType , 
789+ ) 
755790where 
756791    T :  HasAfEnum , 
757-     T :: AggregateOutType :  HasAfEnum  + Fromf64 , 
792+     <T  as  HasAfEnum >:: AggregateOutType :  HasAfEnum , 
793+     <<T  as  HasAfEnum >:: AggregateOutType  as  HasAfEnum >:: BaseType :  HasAfEnum  + Fromf64 , 
758794{ 
759795    let  mut  real:  f64  = 0.0 ; 
760796    let  mut  imag:  f64  = 0.0 ; 
@@ -768,8 +804,8 @@ where
768804        HANDLE_ERROR ( AfError :: from ( err_val) ) ; 
769805    } 
770806    ( 
771-         <T :: AggregateOutType > :: fromf64 ( real) , 
772-         <T :: AggregateOutType > :: fromf64 ( imag) , 
807+         << T   as   HasAfEnum > :: AggregateOutType   as   HasAfEnum > :: BaseType :: fromf64 ( real) , 
808+         << T   as   HasAfEnum > :: AggregateOutType   as   HasAfEnum > :: BaseType :: fromf64 ( imag) , 
773809    ) 
774810} 
775811
@@ -788,10 +824,17 @@ where
788824/// A tuple of product result. 
789825/// 
790826/// Note: For non-complex data type Arrays, second value of tuple is zero. 
791- pub  fn  product_nan_all < T > ( input :  & Array < T > ,  val :  f64 )  -> ( T :: ProductOutType ,  T :: ProductOutType ) 
827+ pub  fn  product_nan_all < T > ( 
828+     input :  & Array < T > , 
829+     val :  f64 , 
830+ )  -> ( 
831+     <<T  as  HasAfEnum >:: ProductOutType  as  HasAfEnum >:: BaseType , 
832+     <<T  as  HasAfEnum >:: ProductOutType  as  HasAfEnum >:: BaseType , 
833+ ) 
792834where 
793835    T :  HasAfEnum , 
794-     T :: ProductOutType :  HasAfEnum  + Fromf64 , 
836+     <T  as  HasAfEnum >:: ProductOutType :  HasAfEnum , 
837+     <<T  as  HasAfEnum >:: ProductOutType  as  HasAfEnum >:: BaseType :  HasAfEnum  + Fromf64 , 
795838{ 
796839    let  mut  real:  f64  = 0.0 ; 
797840    let  mut  imag:  f64  = 0.0 ; 
@@ -805,8 +848,8 @@ where
805848        HANDLE_ERROR ( AfError :: from ( err_val) ) ; 
806849    } 
807850    ( 
808-         <T :: ProductOutType > :: fromf64 ( real) , 
809-         <T :: ProductOutType > :: fromf64 ( imag) , 
851+         << T   as   HasAfEnum > :: ProductOutType   as   HasAfEnum > :: BaseType :: fromf64 ( real) , 
852+         << T   as   HasAfEnum > :: ProductOutType   as   HasAfEnum > :: BaseType :: fromf64 ( imag) , 
810853    ) 
811854} 
812855
@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858901    " ,  imax,  af_imax,  InType ) ; 
859902
860903macro_rules!  all_ireduce_func_def { 
861-     ( $doc_str:  expr,  $fn_name:  ident,  $ffi_name:  ident,  $out_type : ty )  => { 
904+     ( $doc_str:  expr,  $fn_name:  ident,  $ffi_name:  ident,  $assoc_type : ident )  => { 
862905        #[ doc=$doc_str] 
863-         pub  fn  $fn_name<T >( input:  & Array <T >)  -> ( $out_type,  $out_type,  u32 ) 
906+         pub  fn  $fn_name<T >( input:  & Array <T >) 
907+             -> ( 
908+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType , 
909+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType , 
910+                 u32 
911+                ) 
864912        where 
865913            T :  HasAfEnum , 
866-             $out_type:  HasAfEnum  + Fromf64 
914+             <T  as  HasAfEnum >:: $assoc_type:  HasAfEnum , 
915+             <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :  HasAfEnum  + Fromf64 , 
867916        { 
868917            let  mut  real:  f64  = 0.0 ; 
869918            let  mut  imag:  f64  = 0.0 ; 
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875924                ) ; 
876925                HANDLE_ERROR ( AfError :: from( err_val) ) ; 
877926            } 
878-             ( <$out_type>:: fromf64( real) ,  <$out_type>:: fromf64( imag) ,  temp) 
927+             ( 
928+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :: fromf64( real) , 
929+                 <<T  as  HasAfEnum >:: $assoc_type as  HasAfEnum >:: BaseType :: fromf64( imag) , 
930+                 temp, 
931+             ) 
879932        } 
880933    } ; 
881934} 
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898951    " , 
899952    imin_all, 
900953    af_imin_all, 
901-     T :: InType 
954+     InType 
902955) ; 
903956all_ireduce_func_def ! ( 
904957    " 
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918971    " , 
919972    imax_all, 
920973    af_imax_all, 
921-     T :: InType 
974+     InType 
922975) ; 
923976
924977/// Locate the indices of non-zero elements. 
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
13861439    af_product_by_key_nan, 
13871440    ValueType :: ProductOutType 
13881441) ; 
1442+ 
1443+ #[ cfg( test) ]  
1444+ mod  tests { 
1445+     use  super :: super :: core:: c32; 
1446+     use  super :: { product_nan_all,  sum_all,  sum_nan_all,  imin_all,  imax_all} ; 
1447+     use  crate :: randu; 
1448+ 
1449+     #[ test]  
1450+     fn  all_reduce_api ( )  { 
1451+         let  a = randu ! ( c32;  10 ,  10 ) ; 
1452+         println ! ( "Reduction of complex f32 matrix: {:?}" ,  sum_all( & a) ) ; 
1453+ 
1454+         let  b = randu ! ( bool ;  10 ,  10 ) ; 
1455+         println ! ( "reduction of bool matrix: {:?}" ,  sum_all( & b) ) ; 
1456+ 
1457+         println ! ( 
1458+             "reduction of complex f32 matrix after replacing nan with {}: {:?}" , 
1459+             1.0 , 
1460+             product_nan_all( & a,  1.0 ) 
1461+         ) ; 
1462+ 
1463+         println ! ( 
1464+             "reduction of bool matrix after replacing nan with {}: {:?}" , 
1465+             0.0 , 
1466+             sum_nan_all( & b,  0.0 ) 
1467+         ) ; 
1468+     } 
1469+ 
1470+     #[ test]  
1471+     fn  all_ireduce_api ( )  { 
1472+         let  a = randu ! ( c32;  10 ) ; 
1473+         println ! ( "Reduction of complex f32 matrix: {:?}" ,  imin_all( & a) ) ; 
1474+ 
1475+         let  b = randu ! ( u32 ;  10 ) ; 
1476+         println ! ( "reduction of bool matrix: {:?}" ,  imax_all( & b) ) ; 
1477+     } 
1478+ } 
0 commit comments