1
1
use super :: core:: {
2
- af_array, AfError , Array , BinaryOp , HasAfEnum , RealNumber , ReduceByKeyInput , Scanable ,
2
+ af_array, AfError , Array , BinaryOp , Fromf64 , HasAfEnum , RealNumber , ReduceByKeyInput , Scanable ,
3
3
HANDLE_ERROR ,
4
4
} ;
5
5
@@ -518,9 +518,13 @@ where
518
518
}
519
519
520
520
macro_rules! all_reduce_func_def {
521
- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
521
+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
522
522
#[ doc=$doc_str]
523
- pub fn $fn_name<T : HasAfEnum >( input: & Array <T >) -> ( f64 , f64 ) {
523
+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
524
+ where
525
+ T : HasAfEnum ,
526
+ $out_type: HasAfEnum + Fromf64
527
+ {
524
528
let mut real: f64 = 0.0 ;
525
529
let mut imag: f64 = 0.0 ;
526
530
unsafe {
@@ -529,7 +533,7 @@ macro_rules! all_reduce_func_def {
529
533
) ;
530
534
HANDLE_ERROR ( AfError :: from( err_val) ) ;
531
535
}
532
- ( real, imag)
536
+ ( <$out_type> :: fromf64 ( real) , <$out_type> :: fromf64 ( imag) )
533
537
}
534
538
} ;
535
539
}
@@ -559,7 +563,8 @@ all_reduce_func_def!(
559
563
```
560
564
" ,
561
565
sum_all,
562
- af_sum_all
566
+ af_sum_all,
567
+ T :: AggregateOutType
563
568
) ;
564
569
565
570
all_reduce_func_def ! (
@@ -588,7 +593,8 @@ all_reduce_func_def!(
588
593
```
589
594
" ,
590
595
product_all,
591
- af_product_all
596
+ af_product_all,
597
+ T :: ProductOutType
592
598
) ;
593
599
594
600
all_reduce_func_def ! (
@@ -616,7 +622,8 @@ all_reduce_func_def!(
616
622
```
617
623
" ,
618
624
min_all,
619
- af_min_all
625
+ af_min_all,
626
+ T :: InType
620
627
) ;
621
628
622
629
all_reduce_func_def ! (
@@ -644,7 +651,8 @@ all_reduce_func_def!(
644
651
```
645
652
" ,
646
653
max_all,
647
- af_max_all
654
+ af_max_all,
655
+ T :: InType
648
656
) ;
649
657
650
658
all_reduce_func_def ! (
@@ -670,7 +678,8 @@ all_reduce_func_def!(
670
678
```
671
679
" ,
672
680
all_true_all,
673
- af_all_true_all
681
+ af_all_true_all,
682
+ bool
674
683
) ;
675
684
676
685
all_reduce_func_def ! (
@@ -696,7 +705,8 @@ all_reduce_func_def!(
696
705
```
697
706
" ,
698
707
any_true_all,
699
- af_any_true_all
708
+ af_any_true_all,
709
+ bool
700
710
) ;
701
711
702
712
all_reduce_func_def ! (
@@ -722,7 +732,8 @@ all_reduce_func_def!(
722
732
```
723
733
" ,
724
734
count_all,
725
- af_count_all
735
+ af_count_all,
736
+ u64
726
737
) ;
727
738
728
739
/// Sum all values using user provided value for `NAN`
@@ -740,7 +751,11 @@ all_reduce_func_def!(
740
751
/// A tuple of summation result.
741
752
///
742
753
/// Note: For non-complex data type Arrays, second value of tuple is zero.
743
- pub fn sum_nan_all < T : HasAfEnum > ( input : & Array < T > , val : f64 ) -> ( f64 , f64 ) {
754
+ pub fn sum_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: AggregateOutType , T :: AggregateOutType )
755
+ where
756
+ T : HasAfEnum ,
757
+ T :: AggregateOutType : HasAfEnum + Fromf64 ,
758
+ {
744
759
let mut real: f64 = 0.0 ;
745
760
let mut imag: f64 = 0.0 ;
746
761
unsafe {
@@ -752,7 +767,10 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
752
767
) ;
753
768
HANDLE_ERROR ( AfError :: from ( err_val) ) ;
754
769
}
755
- ( real, imag)
770
+ (
771
+ <T :: AggregateOutType >:: fromf64 ( real) ,
772
+ <T :: AggregateOutType >:: fromf64 ( imag) ,
773
+ )
756
774
}
757
775
758
776
/// Product of all values using user provided value for `NAN`
@@ -770,7 +788,11 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
770
788
/// A tuple of product result.
771
789
///
772
790
/// Note: For non-complex data type Arrays, second value of tuple is zero.
773
- pub fn product_nan_all < T : HasAfEnum > ( input : & Array < T > , val : f64 ) -> ( f64 , f64 ) {
791
+ pub fn product_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: ProductOutType , T :: ProductOutType )
792
+ where
793
+ T : HasAfEnum ,
794
+ T :: ProductOutType : HasAfEnum + Fromf64 ,
795
+ {
774
796
let mut real: f64 = 0.0 ;
775
797
let mut imag: f64 = 0.0 ;
776
798
unsafe {
@@ -782,7 +804,10 @@ pub fn product_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
782
804
) ;
783
805
HANDLE_ERROR ( AfError :: from ( err_val) ) ;
784
806
}
785
- ( real, imag)
807
+ (
808
+ <T :: ProductOutType >:: fromf64 ( real) ,
809
+ <T :: ProductOutType >:: fromf64 ( imag) ,
810
+ )
786
811
}
787
812
788
813
macro_rules! dim_ireduce_func_def {
@@ -833,9 +858,13 @@ dim_ireduce_func_def!("
833
858
" , imax, af_imax, InType ) ;
834
859
835
860
macro_rules! all_ireduce_func_def {
836
- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
861
+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
837
862
#[ doc=$doc_str]
838
- pub fn $fn_name<T : HasAfEnum >( input: & Array <T >) -> ( f64 , f64 , u32 ) {
863
+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type, u32 )
864
+ where
865
+ T : HasAfEnum ,
866
+ $out_type: HasAfEnum + Fromf64
867
+ {
839
868
let mut real: f64 = 0.0 ;
840
869
let mut imag: f64 = 0.0 ;
841
870
let mut temp: u32 = 0 ;
@@ -846,7 +875,7 @@ macro_rules! all_ireduce_func_def {
846
875
) ;
847
876
HANDLE_ERROR ( AfError :: from( err_val) ) ;
848
877
}
849
- ( real, imag, temp)
878
+ ( <$out_type> :: fromf64 ( real) , <$out_type> :: fromf64 ( imag) , temp)
850
879
}
851
880
} ;
852
881
}
@@ -868,7 +897,8 @@ all_ireduce_func_def!(
868
897
* index of minimum element in the third component.
869
898
" ,
870
899
imin_all,
871
- af_imin_all
900
+ af_imin_all,
901
+ T :: InType
872
902
) ;
873
903
all_ireduce_func_def ! (
874
904
"
@@ -887,7 +917,8 @@ all_ireduce_func_def!(
887
917
- index of maximum element in the third component.
888
918
" ,
889
919
imax_all,
890
- af_imax_all
920
+ af_imax_all,
921
+ T :: InType
891
922
) ;
892
923
893
924
/// Locate the indices of non-zero elements.
0 commit comments