Skip to content

Commit 85fa980

Browse files
Introduce trait-driven arithmetic
This allows collapsing 5~10 instances of a function on the Simd type into 1-3 copies, at least from the perspective of the docs. The result is far more legible to a user.
1 parent 36e198b commit 85fa980

File tree

6 files changed

+369
-273
lines changed

6 files changed

+369
-273
lines changed

crates/core_simd/src/math.rs

Lines changed: 150 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,156 @@
1-
use crate::simd::intrinsics::{simd_saturating_add, simd_saturating_sub};
2-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
3-
4-
macro_rules! impl_uint_arith {
5-
($($ty:ty),+) => {
6-
$( impl<const LANES: usize> Simd<$ty, LANES> where LaneCount<LANES>: SupportedLaneCount {
7-
8-
/// Lanewise saturating add.
9-
///
10-
/// # Examples
11-
/// ```
12-
/// # #![feature(portable_simd)]
13-
/// # #[cfg(feature = "std")] use core_simd::Simd;
14-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
15-
#[doc = concat!("# use core::", stringify!($ty), "::MAX;")]
16-
/// let x = Simd::from_array([2, 1, 0, MAX]);
17-
/// let max = Simd::splat(MAX);
18-
/// let unsat = x + max;
19-
/// let sat = x.saturating_add(max);
20-
/// assert_eq!(x - 1, unsat);
21-
/// assert_eq!(sat, max);
22-
/// ```
23-
#[inline]
24-
pub fn saturating_add(self, second: Self) -> Self {
25-
unsafe { simd_saturating_add(self, second) }
26-
}
27-
28-
/// Lanewise saturating subtract.
29-
///
30-
/// # Examples
31-
/// ```
32-
/// # #![feature(portable_simd)]
33-
/// # #[cfg(feature = "std")] use core_simd::Simd;
34-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
35-
#[doc = concat!("# use core::", stringify!($ty), "::MAX;")]
36-
/// let x = Simd::from_array([2, 1, 0, MAX]);
37-
/// let max = Simd::splat(MAX);
38-
/// let unsat = x - max;
39-
/// let sat = x.saturating_sub(max);
40-
/// assert_eq!(unsat, x + 1);
41-
/// assert_eq!(sat, Simd::splat(0));
42-
#[inline]
43-
pub fn saturating_sub(self, second: Self) -> Self {
44-
unsafe { simd_saturating_sub(self, second) }
45-
}
46-
})+
1+
use crate::simd::intrinsics;
2+
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
3+
4+
mod sealed {
5+
pub trait Sealed {}
6+
}
7+
use sealed::Sealed;
8+
impl<T, const LANES: usize> Sealed for Simd<T, LANES>
9+
where
10+
T: SimdElement,
11+
LaneCount<LANES>: SupportedLaneCount,
12+
{
13+
}
14+
15+
impl<T, const LANES: usize> Simd<T, LANES>
16+
where
17+
T: Int,
18+
LaneCount<LANES>: SupportedLaneCount,
19+
{
20+
/// Lanewise saturating add.
21+
///
22+
/// # Examples
23+
/// ```
24+
/// # #![feature(portable_simd)]
25+
/// # #[cfg(feature = "std")] use core_simd::Simd;
26+
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
27+
/// let x = Simd::from_array([i32::MIN, 0, 1, i32::MAX]);
28+
/// let max = Simd::splat(i32::MAX);
29+
/// let unsat = x + max;
30+
/// let sat = x.saturating_add(max);
31+
/// assert_eq!(unsat, Simd::from_array([-1, i32::MAX, i32::MIN, -2]));
32+
/// assert_eq!(sat, Simd::from_array([-1, i32::MAX, i32::MAX, i32::MAX]));
33+
/// ```
34+
#[inline]
35+
pub fn saturating_add(self, other: Self) -> Self {
36+
unsafe { intrinsics::simd_saturating_add(self, other) }
4737
}
38+
39+
/// Lanewise saturating subtract.
40+
///
41+
/// # Examples
42+
/// ```
43+
/// # #![feature(portable_simd)]
44+
/// # #[cfg(feature = "std")] use core_simd::Simd;
45+
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
46+
/// let x = Simd::from_array([i32::MIN, -2, -1, i32::MAX]);
47+
/// let max = Simd::splat(i32::MAX);
48+
/// let unsat = x - max;
49+
/// let sat = x.saturating_sub(max);
50+
/// assert_eq!(unsat, Simd::from_array([1, i32::MAX, i32::MIN, 0]));
51+
/// assert_eq!(sat, Simd::from_array([i32::MIN, i32::MIN, i32::MIN, 0]));
52+
#[inline]
53+
pub fn saturating_sub(self, other: Self) -> Self {
54+
unsafe { intrinsics::simd_saturating_sub(self, other) }
55+
}
56+
}
57+
58+
pub trait Int: SimdElement + PartialOrd {
59+
const BITS: u32;
60+
}
61+
62+
impl Int for u8 {
63+
const BITS: u32 = 8;
64+
}
65+
66+
impl Int for i8 {
67+
const BITS: u32 = 8;
68+
}
69+
70+
impl Int for u16 {
71+
const BITS: u32 = 16;
72+
}
73+
74+
impl Int for i16 {
75+
const BITS: u32 = 16;
76+
}
77+
78+
impl Int for u32 {
79+
const BITS: u32 = 32;
80+
}
81+
82+
impl Int for i32 {
83+
const BITS: u32 = 32;
84+
}
85+
86+
impl Int for u64 {
87+
const BITS: u32 = 64;
88+
}
89+
90+
impl Int for i64 {
91+
const BITS: u32 = 64;
92+
}
93+
94+
impl Int for usize {
95+
const BITS: u32 = usize::BITS;
4896
}
4997

50-
macro_rules! impl_int_arith {
51-
($($ty:ty),+) => {
52-
$( impl<const LANES: usize> Simd<$ty, LANES> where LaneCount<LANES>: SupportedLaneCount {
53-
54-
/// Lanewise saturating add.
55-
///
56-
/// # Examples
57-
/// ```
58-
/// # #![feature(portable_simd)]
59-
/// # #[cfg(feature = "std")] use core_simd::Simd;
60-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
61-
#[doc = concat!("# use core::", stringify!($ty), "::{MIN, MAX};")]
62-
/// let x = Simd::from_array([MIN, 0, 1, MAX]);
63-
/// let max = Simd::splat(MAX);
64-
/// let unsat = x + max;
65-
/// let sat = x.saturating_add(max);
66-
/// assert_eq!(unsat, Simd::from_array([-1, MAX, MIN, -2]));
67-
/// assert_eq!(sat, Simd::from_array([-1, MAX, MAX, MAX]));
68-
/// ```
69-
#[inline]
70-
pub fn saturating_add(self, second: Self) -> Self {
71-
unsafe { simd_saturating_add(self, second) }
72-
}
73-
74-
/// Lanewise saturating subtract.
75-
///
76-
/// # Examples
77-
/// ```
78-
/// # #![feature(portable_simd)]
79-
/// # #[cfg(feature = "std")] use core_simd::Simd;
80-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
81-
#[doc = concat!("# use core::", stringify!($ty), "::{MIN, MAX};")]
82-
/// let x = Simd::from_array([MIN, -2, -1, MAX]);
83-
/// let max = Simd::splat(MAX);
84-
/// let unsat = x - max;
85-
/// let sat = x.saturating_sub(max);
86-
/// assert_eq!(unsat, Simd::from_array([1, MAX, MIN, 0]));
87-
/// assert_eq!(sat, Simd::from_array([MIN, MIN, MIN, 0]));
88-
#[inline]
89-
pub fn saturating_sub(self, second: Self) -> Self {
90-
unsafe { simd_saturating_sub(self, second) }
91-
}
92-
93-
/// Lanewise absolute value, implemented in Rust.
94-
/// Every lane becomes its absolute value.
95-
///
96-
/// # Examples
97-
/// ```
98-
/// # #![feature(portable_simd)]
99-
/// # #[cfg(feature = "std")] use core_simd::Simd;
100-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
101-
#[doc = concat!("# use core::", stringify!($ty), "::{MIN, MAX};")]
102-
/// let xs = Simd::from_array([MIN, MIN +1, -5, 0]);
103-
/// assert_eq!(xs.abs(), Simd::from_array([MIN, MAX, 5, 0]));
104-
/// ```
105-
#[inline]
106-
pub fn abs(self) -> Self {
107-
const SHR: $ty = <$ty>::BITS as $ty - 1;
108-
let m = self >> SHR;
109-
(self^m) - m
110-
}
111-
112-
/// Lanewise saturating absolute value, implemented in Rust.
113-
/// As abs(), except the MIN value becomes MAX instead of itself.
114-
///
115-
/// # Examples
116-
/// ```
117-
/// # #![feature(portable_simd)]
118-
/// # #[cfg(feature = "std")] use core_simd::Simd;
119-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
120-
#[doc = concat!("# use core::", stringify!($ty), "::{MIN, MAX};")]
121-
/// let xs = Simd::from_array([MIN, -2, 0, 3]);
122-
/// let unsat = xs.abs();
123-
/// let sat = xs.saturating_abs();
124-
/// assert_eq!(unsat, Simd::from_array([MIN, 2, 0, 3]));
125-
/// assert_eq!(sat, Simd::from_array([MAX, 2, 0, 3]));
126-
/// ```
127-
#[inline]
128-
pub fn saturating_abs(self) -> Self {
129-
// arith shift for -1 or 0 mask based on sign bit, giving 2s complement
130-
const SHR: $ty = <$ty>::BITS as $ty - 1;
131-
let m = self >> SHR;
132-
(self^m).saturating_sub(m)
133-
}
134-
135-
/// Lanewise saturating negation, implemented in Rust.
136-
/// As neg(), except the MIN value becomes MAX instead of itself.
137-
///
138-
/// # Examples
139-
/// ```
140-
/// # #![feature(portable_simd)]
141-
/// # #[cfg(feature = "std")] use core_simd::Simd;
142-
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
143-
#[doc = concat!("# use core::", stringify!($ty), "::{MIN, MAX};")]
144-
/// let x = Simd::from_array([MIN, -2, 3, MAX]);
145-
/// let unsat = -x;
146-
/// let sat = x.saturating_neg();
147-
/// assert_eq!(unsat, Simd::from_array([MIN, 2, -3, MIN + 1]));
148-
/// assert_eq!(sat, Simd::from_array([MAX, 2, -3, MIN + 1]));
149-
/// ```
150-
#[inline]
151-
pub fn saturating_neg(self) -> Self {
152-
Self::splat(0).saturating_sub(self)
153-
}
154-
})+
98+
impl Int for isize {
99+
const BITS: u32 = isize::BITS;
100+
}
101+
102+
pub trait SimdSignum: Sealed {
103+
fn signum(self) -> Self;
104+
}
105+
106+
impl<T, const LANES: usize> Simd<T, LANES>
107+
where
108+
Self: SimdSignum,
109+
T: SimdElement,
110+
LaneCount<LANES>: SupportedLaneCount,
111+
{
112+
/// Replaces each lane with a number that represents its sign.
113+
///
114+
/// For floats:
115+
/// * `1.0` if the number is positive, `+0.0`, or `INFINITY`
116+
/// * `-1.0` if the number is negative, `-0.0`, or `NEG_INFINITY`
117+
/// * `NAN` if the number is `NAN`
118+
///
119+
/// For signed integers:
120+
/// * `0` if the number is zero
121+
/// * `1` if the number is positive
122+
/// * `-1` if the number is negative
123+
#[inline]
124+
pub fn signum(self) -> Self {
125+
<Self as SimdSignum>::signum(self)
155126
}
156127
}
157128

158-
impl_uint_arith! { u8, u16, u32, u64, usize }
159-
impl_int_arith! { i8, i16, i32, i64, isize }
129+
pub trait SimdAbs: Sealed {
130+
/// Returns a vector where every lane has the absolute value of the
131+
/// equivalent index in `self`.
132+
fn abs(self) -> Self;
133+
}
134+
135+
impl<T, const LANES: usize> Simd<T, LANES>
136+
where
137+
Self: SimdAbs,
138+
T: SimdElement,
139+
LaneCount<LANES>: SupportedLaneCount,
140+
{
141+
/// Returns a vector where every lane has the absolute value of the
142+
/// equivalent index in `self`.
143+
///
144+
/// # Examples
145+
/// ```rust
146+
/// # #![feature(portable_simd)]
147+
/// # #[cfg(feature = "std")] use core_simd::Simd;
148+
/// # #[cfg(not(feature = "std"))] use core::simd::Simd;
149+
/// let xs = Simd::from_array([i32::MIN, i32::MIN +1, -5, 0]);
150+
/// assert_eq!(xs.abs(), Simd::from_array([i32::MIN, i32::MAX, 5, 0]));
151+
/// ```
152+
#[inline]
153+
pub fn abs(self) -> Self {
154+
<Self as SimdAbs>::abs(self)
155+
}
156+
}

crates/core_simd/src/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod reduction;
55
mod swizzle;
66

77
pub(crate) mod intrinsics;
8+
pub(crate) mod math;
89

910
#[cfg(feature = "generic_const_exprs")]
1011
mod to_bytes;
@@ -14,7 +15,6 @@ mod fmt;
1415
mod iter;
1516
mod lane_count;
1617
mod masks;
17-
mod math;
1818
mod ops;
1919
mod round;
2020
mod select;
@@ -24,6 +24,7 @@ mod vendor;
2424
#[doc = include_str!("core_simd_docs.md")]
2525
pub mod simd {
2626
pub(crate) use crate::core_simd::intrinsics;
27+
pub(crate) use crate::core_simd::math::*;
2728

2829
pub use crate::core_simd::lane_count::{LaneCount, SupportedLaneCount};
2930
pub use crate::core_simd::masks::*;

0 commit comments

Comments
 (0)