Skip to content

Commit 87a5aaa

Browse files
committed
Fix return type trait bound on reduce all functions
1 parent 6670d5c commit 87a5aaa

File tree

1 file changed

+115
-25
lines changed

1 file changed

+115
-25
lines changed

src/algorithm/mod.rs

Lines changed: 115 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,17 @@ where
518518
}
519519

520520
macro_rules! all_reduce_func_def {
521-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
521+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
522522
#[doc=$doc_str]
523-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
523+
pub fn $fn_name<T>(input: &Array<T>)
524+
-> (
525+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
526+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType
527+
)
524528
where
525529
T: HasAfEnum,
526-
$out_type: HasAfEnum + Fromf64
530+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
531+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
527532
{
528533
let mut real: f64 = 0.0;
529534
let mut imag: f64 = 0.0;
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533538
);
534539
HANDLE_ERROR(AfError::from(err_val));
535540
}
536-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
541+
(
542+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
543+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
544+
)
537545
}
538546
};
539547
}
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564572
",
565573
sum_all,
566574
af_sum_all,
567-
T::AggregateOutType
575+
AggregateOutType
568576
);
569577

570578
all_reduce_func_def!(
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594602
",
595603
product_all,
596604
af_product_all,
597-
T::ProductOutType
605+
ProductOutType
598606
);
599607

600608
all_reduce_func_def!(
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623631
",
624632
min_all,
625633
af_min_all,
626-
T::InType
634+
InType
627635
);
628636

629637
all_reduce_func_def!(
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652660
",
653661
max_all,
654662
af_max_all,
655-
T::InType
663+
InType
656664
);
657665

658-
all_reduce_func_def!(
666+
macro_rules! all_reduce_func_def2 {
667+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
668+
#[doc=$doc_str]
669+
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
670+
where
671+
T: HasAfEnum,
672+
$out_type: HasAfEnum + Fromf64
673+
{
674+
let mut real: f64 = 0.0;
675+
let mut imag: f64 = 0.0;
676+
unsafe {
677+
let err_val = $ffi_name(
678+
&mut real as *mut c_double, &mut imag as *mut c_double, input.get(),
679+
);
680+
HANDLE_ERROR(AfError::from(err_val));
681+
}
682+
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
683+
}
684+
};
685+
}
686+
687+
all_reduce_func_def2!(
659688
"
660689
Find if all values of Array are non-zero
661690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682711
bool
683712
);
684713

685-
all_reduce_func_def!(
714+
all_reduce_func_def2!(
686715
"
687716
Find if any value of Array is non-zero
688717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709738
bool
710739
);
711740

712-
all_reduce_func_def!(
741+
all_reduce_func_def2!(
713742
"
714743
Count number of non-zero values in the Array
715744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751780
/// A tuple of summation result.
752781
///
753782
/// Note: For non-complex data type Arrays, second value of tuple is zero.
754-
pub fn sum_nan_all<T>(input: &Array<T>, val: f64) -> (T::AggregateOutType, T::AggregateOutType)
783+
pub fn sum_nan_all<T>(
784+
input: &Array<T>,
785+
val: f64,
786+
) -> (
787+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
788+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
789+
)
755790
where
756791
T: HasAfEnum,
757-
T::AggregateOutType: HasAfEnum + Fromf64,
792+
<T as HasAfEnum>::AggregateOutType: HasAfEnum,
793+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
758794
{
759795
let mut real: f64 = 0.0;
760796
let mut imag: f64 = 0.0;
@@ -768,8 +804,8 @@ where
768804
HANDLE_ERROR(AfError::from(err_val));
769805
}
770806
(
771-
<T::AggregateOutType>::fromf64(real),
772-
<T::AggregateOutType>::fromf64(imag),
807+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(real),
808+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag),
773809
)
774810
}
775811

@@ -788,10 +824,17 @@ where
788824
/// A tuple of product result.
789825
///
790826
/// Note: For non-complex data type Arrays, second value of tuple is zero.
791-
pub fn product_nan_all<T>(input: &Array<T>, val: f64) -> (T::ProductOutType, T::ProductOutType)
827+
pub fn product_nan_all<T>(
828+
input: &Array<T>,
829+
val: f64,
830+
) -> (
831+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
832+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
833+
)
792834
where
793835
T: HasAfEnum,
794-
T::ProductOutType: HasAfEnum + Fromf64,
836+
<T as HasAfEnum>::ProductOutType: HasAfEnum,
837+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
795838
{
796839
let mut real: f64 = 0.0;
797840
let mut imag: f64 = 0.0;
@@ -805,8 +848,8 @@ where
805848
HANDLE_ERROR(AfError::from(err_val));
806849
}
807850
(
808-
<T::ProductOutType>::fromf64(real),
809-
<T::ProductOutType>::fromf64(imag),
851+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(real),
852+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(imag),
810853
)
811854
}
812855

@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858901
", imax, af_imax, InType);
859902

860903
macro_rules! all_ireduce_func_def {
861-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
904+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
862905
#[doc=$doc_str]
863-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type, u32)
906+
pub fn $fn_name<T>(input: &Array<T>)
907+
-> (
908+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
909+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
910+
u32
911+
)
864912
where
865913
T: HasAfEnum,
866-
$out_type: HasAfEnum + Fromf64
914+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
915+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
867916
{
868917
let mut real: f64 = 0.0;
869918
let mut imag: f64 = 0.0;
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875924
);
876925
HANDLE_ERROR(AfError::from(err_val));
877926
}
878-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp)
927+
(
928+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
929+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
930+
temp,
931+
)
879932
}
880933
};
881934
}
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898951
",
899952
imin_all,
900953
af_imin_all,
901-
T::InType
954+
InType
902955
);
903956
all_ireduce_func_def!(
904957
"
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918971
",
919972
imax_all,
920973
af_imax_all,
921-
T::InType
974+
InType
922975
);
923976

924977
/// Locate the indices of non-zero elements.
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
13861439
af_product_by_key_nan,
13871440
ValueType::ProductOutType
13881441
);
1442+
1443+
#[cfg(test)]
1444+
mod tests {
1445+
use super::super::core::c32;
1446+
use super::{product_nan_all, sum_all, sum_nan_all, imin_all, imax_all};
1447+
use crate::randu;
1448+
1449+
#[test]
1450+
fn all_reduce_api() {
1451+
let a = randu!(c32; 10, 10);
1452+
println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));
1453+
1454+
let b = randu!(bool; 10, 10);
1455+
println!("reduction of bool matrix: {:?}", sum_all(&b));
1456+
1457+
println!(
1458+
"reduction of complex f32 matrix after replacing nan with {}: {:?}",
1459+
1.0,
1460+
product_nan_all(&a, 1.0)
1461+
);
1462+
1463+
println!(
1464+
"reduction of bool matrix after replacing nan with {}: {:?}",
1465+
0.0,
1466+
sum_nan_all(&b, 0.0)
1467+
);
1468+
}
1469+
1470+
#[test]
1471+
fn all_ireduce_api() {
1472+
let a = randu!(c32; 10);
1473+
println!("Reduction of complex f32 matrix: {:?}", imin_all(&a));
1474+
1475+
let b = randu!(u32; 10);
1476+
println!("reduction of bool matrix: {:?}", imax_all(&b));
1477+
}
1478+
}

0 commit comments

Comments
 (0)