Skip to content

Commit 4c99aa4

Browse files
Collapse reductions in docs using traits
1 parent 6b93a89 commit 4c99aa4

File tree

1 file changed

+105
-62
lines changed

1 file changed

+105
-62
lines changed

crates/core_simd/src/reduction.rs

Lines changed: 105 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,117 @@ use crate::simd::intrinsics::{
22
simd_reduce_add_ordered, simd_reduce_and, simd_reduce_max, simd_reduce_min,
33
simd_reduce_mul_ordered, simd_reduce_or, simd_reduce_xor,
44
};
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
5+
use crate::simd::{Int, LaneCount, Simd, SimdElement, SupportedLaneCount};
6+
7+
impl<T, const LANES: usize> Simd<T, LANES>
8+
where
9+
T: Int,
10+
LaneCount<LANES>: SupportedLaneCount,
11+
{
12+
/// Horizontal bitwise "and". Returns the cumulative bitwise "and" across the lanes of
13+
/// the vector.
14+
#[inline]
15+
pub fn horizontal_and(self) -> T {
16+
unsafe { simd_reduce_and(self) }
17+
}
18+
19+
/// Horizontal bitwise "or". Returns the cumulative bitwise "or" across the lanes of
20+
/// the vector.
21+
#[inline]
22+
pub fn horizontal_or(self) -> T {
23+
unsafe { simd_reduce_or(self) }
24+
}
25+
26+
/// Horizontal bitwise "xor". Returns the cumulative bitwise "xor" across the lanes of
27+
/// the vector.
28+
#[inline]
29+
pub fn horizontal_xor(self) -> T {
30+
unsafe { simd_reduce_xor(self) }
31+
}
32+
}
33+
34+
impl<T, const LANES: usize> Simd<T, LANES>
35+
where
36+
T: SimdElement,
37+
LaneCount<LANES>: SupportedLaneCount,
38+
{
39+
/// Horizontal maximum. Returns the maximum lane in the vector.
40+
///
41+
/// Returns values based on equality, so a vector containing both `0.` and `-0.` may
42+
/// return either. This function will not return `NaN` unless all lanes are `NaN`.
43+
#[inline]
44+
pub fn horizontal_max(self) -> T {
45+
unsafe { simd_reduce_max(self) }
46+
}
47+
48+
/// Horizontal minimum. Returns the minimum lane in the vector.
49+
///
50+
/// Returns values based on equality, so a vector containing both `0.` and `-0.` may
51+
/// return either. This function will not return `NaN` unless all lanes are `NaN`.
52+
#[inline]
53+
pub fn horizontal_min(self) -> T {
54+
unsafe { simd_reduce_min(self) }
55+
}
56+
}
57+
58+
impl<T, const LANES: usize> Simd<T, LANES>
59+
where
60+
Self: HorizontalArith<Scalar = T>,
61+
T: SimdElement,
62+
LaneCount<LANES>: SupportedLaneCount,
63+
{
64+
/// Horizontal add. Returns the sum of the lanes of the vector.
65+
#[inline]
66+
pub fn horizontal_sum(self) -> T {
67+
<Self as HorizontalArith>::horizontal_sum(self)
68+
}
69+
70+
/// Horizontal multiply. Returns the product of the lanes of the vector.
71+
#[inline]
72+
pub fn horizontal_product(self) -> T {
73+
<Self as HorizontalArith>::horizontal_product(self)
74+
}
75+
}
76+
77+
mod sealed {
78+
pub trait Sealed {}
79+
}
80+
use sealed::Sealed;
81+
impl<T, const LANES: usize> Sealed for Simd<T, LANES>
82+
where
83+
T: SimdElement,
84+
LaneCount<LANES>: SupportedLaneCount,
85+
{
86+
}
87+
88+
pub trait HorizontalArith: Sealed {
89+
type Scalar: SimdElement;
90+
/// Horizontal add. Returns the sum of the lanes of the vector.
91+
fn horizontal_sum(self) -> Self::Scalar;
92+
93+
/// Horizontal multiply. Returns the product of the lanes of the vector.
94+
fn horizontal_product(self) -> Self::Scalar;
95+
}
696

797
macro_rules! impl_integer_reductions {
898
{ $scalar:ty } => {
9-
impl<const LANES: usize> Simd<$scalar, LANES>
99+
impl<const LANES: usize> HorizontalArith for Simd<$scalar, LANES>
10100
where
11-
LaneCount<LANES>: SupportedLaneCount,
12-
{
101+
LaneCount<LANES>: SupportedLaneCount,
102+
103+
{
104+
type Scalar = $scalar;
13105
/// Horizontal wrapping add. Returns the sum of the lanes of the vector, with wrapping addition.
14106
#[inline]
15-
pub fn horizontal_sum(self) -> $scalar {
107+
fn horizontal_sum(self) -> $scalar {
16108
unsafe { simd_reduce_add_ordered(self, 0) }
17109
}
18110

19111
/// Horizontal wrapping multiply. Returns the product of the lanes of the vector, with wrapping multiplication.
20112
#[inline]
21-
pub fn horizontal_product(self) -> $scalar {
113+
fn horizontal_product(self) -> $scalar {
22114
unsafe { simd_reduce_mul_ordered(self, 1) }
23115
}
24-
25-
/// Horizontal bitwise "and". Returns the cumulative bitwise "and" across the lanes of
26-
/// the vector.
27-
#[inline]
28-
pub fn horizontal_and(self) -> $scalar {
29-
unsafe { simd_reduce_and(self) }
30-
}
31-
32-
/// Horizontal bitwise "or". Returns the cumulative bitwise "or" across the lanes of
33-
/// the vector.
34-
#[inline]
35-
pub fn horizontal_or(self) -> $scalar {
36-
unsafe { simd_reduce_or(self) }
37-
}
38-
39-
/// Horizontal bitwise "xor". Returns the cumulative bitwise "xor" across the lanes of
40-
/// the vector.
41-
#[inline]
42-
pub fn horizontal_xor(self) -> $scalar {
43-
unsafe { simd_reduce_xor(self) }
44-
}
45-
46-
/// Horizontal maximum. Returns the maximum lane in the vector.
47-
#[inline]
48-
pub fn horizontal_max(self) -> $scalar {
49-
unsafe { simd_reduce_max(self) }
50-
}
51-
52-
/// Horizontal minimum. Returns the minimum lane in the vector.
53-
#[inline]
54-
pub fn horizontal_min(self) -> $scalar {
55-
unsafe { simd_reduce_min(self) }
56-
}
57116
}
58117
}
59118
}
@@ -71,14 +130,16 @@ impl_integer_reductions! { usize }
71130

72131
macro_rules! impl_float_reductions {
73132
{ $scalar:ty } => {
74-
impl<const LANES: usize> Simd<$scalar, LANES>
133+
impl<const LANES: usize> HorizontalArith for Simd<$scalar, LANES>
75134
where
76-
LaneCount<LANES>: SupportedLaneCount,
77-
{
135+
LaneCount<LANES>: SupportedLaneCount,
136+
137+
{
138+
type Scalar = $scalar;
78139

79140
/// Horizontal add. Returns the sum of the lanes of the vector.
80141
#[inline]
81-
pub fn horizontal_sum(self) -> $scalar {
142+
fn horizontal_sum(self) -> $scalar {
82143
// LLVM sum is inaccurate on i586
83144
if cfg!(all(target_arch = "x86", not(target_feature = "sse2"))) {
84145
self.as_array().iter().sum()
@@ -89,32 +150,14 @@ macro_rules! impl_float_reductions {
89150

90151
/// Horizontal multiply. Returns the product of the lanes of the vector.
91152
#[inline]
92-
pub fn horizontal_product(self) -> $scalar {
153+
fn horizontal_product(self) -> $scalar {
93154
// LLVM product is inaccurate on i586
94155
if cfg!(all(target_arch = "x86", not(target_feature = "sse2"))) {
95156
self.as_array().iter().product()
96157
} else {
97158
unsafe { simd_reduce_mul_ordered(self, 1.) }
98159
}
99160
}
100-
101-
/// Horizontal maximum. Returns the maximum lane in the vector.
102-
///
103-
/// Returns values based on equality, so a vector containing both `0.` and `-0.` may
104-
/// return either. This function will not return `NaN` unless all lanes are `NaN`.
105-
#[inline]
106-
pub fn horizontal_max(self) -> $scalar {
107-
unsafe { simd_reduce_max(self) }
108-
}
109-
110-
/// Horizontal minimum. Returns the minimum lane in the vector.
111-
///
112-
/// Returns values based on equality, so a vector containing both `0.` and `-0.` may
113-
/// return either. This function will not return `NaN` unless all lanes are `NaN`.
114-
#[inline]
115-
pub fn horizontal_min(self) -> $scalar {
116-
unsafe { simd_reduce_min(self) }
117-
}
118161
}
119162
}
120163
}

0 commit comments

Comments
 (0)