Skip to content

Commit 39ab7a5

Browse files
committed
Use appropriate output type for reduce all ops
1 parent 8076658 commit 39ab7a5

File tree

5 files changed

+84
-24
lines changed

5 files changed

+84
-24
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ lazy_static = "1.0"
5454
half = "1.5.0"
5555

5656
[dev-dependencies]
57-
float-cmp = "0.6.0"
5857
half = "1.5.0"
5958

6059
[build-dependencies]

examples/pi.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn main() {
2323
let root = &sqrt(xplusy);
2424
let cnst = &constant(1, dims);
2525
let (real, imag) = sum_all(&le(root, cnst, false));
26-
let pi_val = real * 4.0 / (samples as f64);
26+
let pi_val = (real as f64) * 4.0 / (samples as f64);
2727
}
2828

2929
println!("Estimated Pi Value in {:?}", start.elapsed());

src/algorithm/mod.rs

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::core::{
2-
af_array, AfError, Array, BinaryOp, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable,
2+
af_array, AfError, Array, BinaryOp, Fromf64, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable,
33
HANDLE_ERROR,
44
};
55

@@ -518,9 +518,13 @@ where
518518
}
519519

520520
macro_rules! all_reduce_func_def {
521-
($doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
521+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
522522
#[doc=$doc_str]
523-
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> (f64, f64) {
523+
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
524+
where
525+
T: HasAfEnum,
526+
$out_type: HasAfEnum + Fromf64
527+
{
524528
let mut real: f64 = 0.0;
525529
let mut imag: f64 = 0.0;
526530
unsafe {
@@ -529,7 +533,7 @@ macro_rules! all_reduce_func_def {
529533
);
530534
HANDLE_ERROR(AfError::from(err_val));
531535
}
532-
(real, imag)
536+
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
533537
}
534538
};
535539
}
@@ -559,7 +563,8 @@ all_reduce_func_def!(
559563
```
560564
",
561565
sum_all,
562-
af_sum_all
566+
af_sum_all,
567+
T::AggregateOutType
563568
);
564569

565570
all_reduce_func_def!(
@@ -588,7 +593,8 @@ all_reduce_func_def!(
588593
```
589594
",
590595
product_all,
591-
af_product_all
596+
af_product_all,
597+
T::ProductOutType
592598
);
593599

594600
all_reduce_func_def!(
@@ -616,7 +622,8 @@ all_reduce_func_def!(
616622
```
617623
",
618624
min_all,
619-
af_min_all
625+
af_min_all,
626+
T::InType
620627
);
621628

622629
all_reduce_func_def!(
@@ -644,7 +651,8 @@ all_reduce_func_def!(
644651
```
645652
",
646653
max_all,
647-
af_max_all
654+
af_max_all,
655+
T::InType
648656
);
649657

650658
all_reduce_func_def!(
@@ -670,7 +678,8 @@ all_reduce_func_def!(
670678
```
671679
",
672680
all_true_all,
673-
af_all_true_all
681+
af_all_true_all,
682+
bool
674683
);
675684

676685
all_reduce_func_def!(
@@ -696,7 +705,8 @@ all_reduce_func_def!(
696705
```
697706
",
698707
any_true_all,
699-
af_any_true_all
708+
af_any_true_all,
709+
bool
700710
);
701711

702712
all_reduce_func_def!(
@@ -722,7 +732,8 @@ all_reduce_func_def!(
722732
```
723733
",
724734
count_all,
725-
af_count_all
735+
af_count_all,
736+
u64
726737
);
727738

728739
/// Sum all values using user provided value for `NAN`
@@ -740,7 +751,11 @@ all_reduce_func_def!(
740751
/// A tuple of summation result.
741752
///
742753
/// Note: For non-complex data type Arrays, second value of tuple is zero.
743-
pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
754+
pub fn sum_nan_all<T>(input: &Array<T>, val: f64) -> (T::AggregateOutType, T::AggregateOutType)
755+
where
756+
T: HasAfEnum,
757+
T::AggregateOutType: HasAfEnum + Fromf64,
758+
{
744759
let mut real: f64 = 0.0;
745760
let mut imag: f64 = 0.0;
746761
unsafe {
@@ -752,7 +767,10 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
752767
);
753768
HANDLE_ERROR(AfError::from(err_val));
754769
}
755-
(real, imag)
770+
(
771+
<T::AggregateOutType>::fromf64(real),
772+
<T::AggregateOutType>::fromf64(imag),
773+
)
756774
}
757775

758776
/// Product of all values using user provided value for `NAN`
@@ -770,7 +788,11 @@ pub fn sum_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
770788
/// A tuple of product result.
771789
///
772790
/// Note: For non-complex data type Arrays, second value of tuple is zero.
773-
pub fn product_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
791+
pub fn product_nan_all<T>(input: &Array<T>, val: f64) -> (T::ProductOutType, T::ProductOutType)
792+
where
793+
T: HasAfEnum,
794+
T::ProductOutType: HasAfEnum + Fromf64,
795+
{
774796
let mut real: f64 = 0.0;
775797
let mut imag: f64 = 0.0;
776798
unsafe {
@@ -782,7 +804,10 @@ pub fn product_nan_all<T: HasAfEnum>(input: &Array<T>, val: f64) -> (f64, f64) {
782804
);
783805
HANDLE_ERROR(AfError::from(err_val));
784806
}
785-
(real, imag)
807+
(
808+
<T::ProductOutType>::fromf64(real),
809+
<T::ProductOutType>::fromf64(imag),
810+
)
786811
}
787812

788813
macro_rules! dim_ireduce_func_def {
@@ -833,9 +858,13 @@ dim_ireduce_func_def!("
833858
", imax, af_imax, InType);
834859

835860
macro_rules! all_ireduce_func_def {
836-
($doc_str: expr, $fn_name: ident, $ffi_name: ident) => {
861+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
837862
#[doc=$doc_str]
838-
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> (f64, f64, u32) {
863+
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type, u32)
864+
where
865+
T: HasAfEnum,
866+
$out_type: HasAfEnum + Fromf64
867+
{
839868
let mut real: f64 = 0.0;
840869
let mut imag: f64 = 0.0;
841870
let mut temp: u32 = 0;
@@ -846,7 +875,7 @@ macro_rules! all_ireduce_func_def {
846875
);
847876
HANDLE_ERROR(AfError::from(err_val));
848877
}
849-
(real, imag, temp)
878+
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp)
850879
}
851880
};
852881
}
@@ -868,7 +897,8 @@ all_ireduce_func_def!(
868897
* index of minimum element in the third component.
869898
",
870899
imin_all,
871-
af_imin_all
900+
af_imin_all,
901+
T::InType
872902
);
873903
all_ireduce_func_def!(
874904
"
@@ -887,7 +917,8 @@ all_ireduce_func_def!(
887917
- index of maximum element in the third component.
888918
",
889919
imax_all,
890-
af_imax_all
920+
af_imax_all,
921+
T::InType
891922
);
892923

893924
/// Locate the indices of non-zero elements.

src/core/util.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,3 +796,34 @@ impl BitOr for MatProp {
796796
Self::from(self as u32 | rhs as u32)
797797
}
798798
}
799+
800+
/// Trait to convert reduction's scalar output to appropriate output type
801+
///
802+
/// This is an internal trait and ideally of no use to user usecases.
803+
pub trait Fromf64 {
804+
/// Convert to target type from a double precision value
805+
fn fromf64(value: f64) -> Self;
806+
}
807+
808+
#[rustfmt::skip]
809+
impl Fromf64 for usize{ fn fromf64(value: f64) -> Self { value as Self }}
810+
#[rustfmt::skip]
811+
impl Fromf64 for f64 { fn fromf64(value: f64) -> Self { value as Self }}
812+
#[rustfmt::skip]
813+
impl Fromf64 for u64 { fn fromf64(value: f64) -> Self { value as Self }}
814+
#[rustfmt::skip]
815+
impl Fromf64 for i64 { fn fromf64(value: f64) -> Self { value as Self }}
816+
#[rustfmt::skip]
817+
impl Fromf64 for f32 { fn fromf64(value: f64) -> Self { value as Self }}
818+
#[rustfmt::skip]
819+
impl Fromf64 for u32 { fn fromf64(value: f64) -> Self { value as Self }}
820+
#[rustfmt::skip]
821+
impl Fromf64 for i32 { fn fromf64(value: f64) -> Self { value as Self }}
822+
#[rustfmt::skip]
823+
impl Fromf64 for u16 { fn fromf64(value: f64) -> Self { value as Self }}
824+
#[rustfmt::skip]
825+
impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }}
826+
#[rustfmt::skip]
827+
impl Fromf64 for u8 { fn fromf64(value: f64) -> Self { value as Self }}
828+
#[rustfmt::skip]
829+
impl Fromf64 for bool { fn fromf64(value: f64) -> Self { value > 0.0 }}

tests/scalar_arith.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use ::arrayfire::*;
2-
use float_cmp::approx_eq;
32

43
#[test]
54
fn check_scalar_arith() {
@@ -15,5 +14,5 @@ fn check_scalar_arith() {
1514
let scalar_res = all_true_all(&scalar_res_comp);
1615
let res = all_true_all(&res_comp);
1716

18-
assert!(approx_eq!(f64, scalar_res.0, res.0, ulps = 2));
17+
assert!(scalar_res.0 == res.0);
1918
}

0 commit comments

Comments
 (0)