Skip to content

Commit 4df8569

Browse files
committed
struct CaseSet: Optimize by matching on len one per set_ctx.
In C, `case_set` works by `switch`ing once on `len`, once per `set_ctx`, but my `CaseSet` implementation in Rust accidentally switched to `match`ing on `buf.len()` each time. This goes back to the original, optimal behavior. To do this, `set_ctx` is made a rank-2 polymorphic "closure". This does not exist in Rust, so it's emulated through a generic trait with a generic method. This inner generic over `trait CaseSetter` is what allows `fn CaseSet::one` to select the correct `CaseSetterN` at compile time. However, this means that closures can't be used anymore, which is very annoying. To partially remedy this, I added the `set_ctx!` macro, which emulates the `set_ctx` closure as much as possible. All captures (up vars in `rustc`) and their types must be declared.
1 parent c25593e commit 4df8569

File tree

4 files changed

+288
-127
lines changed

4 files changed

+288
-127
lines changed

src/ctx.rs

Lines changed: 105 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,55 +52,121 @@ use std::iter::zip;
5252
/// This optimizes for the common cases where `buf.len()` is a small power of 2,
5353
/// where the array write is optimized as few and large stores as possible.
5454
#[inline]
55-
pub fn small_memset<T: Clone + Copy, const UP_TO: usize, const WITH_DEFAULT: bool>(
55+
pub fn small_memset<T: Clone + Copy, const N: usize, const WITH_DEFAULT: bool>(
5656
buf: &mut [T],
5757
val: T,
5858
) {
5959
fn as_array<T: Clone + Copy, const N: usize>(buf: &mut [T]) -> &mut [T; N] {
6060
buf.try_into().unwrap()
6161
}
62-
match buf.len() {
63-
01 if UP_TO >= 01 => *as_array(buf) = [val; 01],
64-
02 if UP_TO >= 02 => *as_array(buf) = [val; 02],
65-
04 if UP_TO >= 04 => *as_array(buf) = [val; 04],
66-
08 if UP_TO >= 08 => *as_array(buf) = [val; 08],
67-
16 if UP_TO >= 16 => *as_array(buf) = [val; 16],
68-
32 if UP_TO >= 32 => *as_array(buf) = [val; 32],
69-
64 if UP_TO >= 64 => *as_array(buf) = [val; 64],
70-
_ => {
71-
if WITH_DEFAULT {
72-
buf.fill(val)
73-
}
62+
if N == 0 {
63+
if WITH_DEFAULT {
64+
buf.fill(val)
7465
}
66+
} else {
67+
assert!(buf.len() == N); // Meant to be optimized out.
68+
*as_array(buf) = [val; N];
7569
}
7670
}
7771

78-
pub struct CaseSetter<const UP_TO: usize, const WITH_DEFAULT: bool> {
72+
pub trait CaseSetter {
73+
fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T);
74+
75+
/// # Safety
76+
///
77+
/// Caller must ensure that no elements of the written range are concurrently
78+
/// borrowed (immutably or mutably) at all during the call to `set_disjoint`.
79+
fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
80+
where
81+
T: AsMutPtr<Target = V>,
82+
V: Clone + Copy;
83+
}
84+
85+
pub struct CaseSetterN<const N: usize, const WITH_DEFAULT: bool> {
7986
offset: usize,
8087
len: usize,
8188
}
8289

83-
impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSetter<UP_TO, WITH_DEFAULT> {
90+
impl<const N: usize, const WITH_DEFAULT: bool> CaseSetterN<N, WITH_DEFAULT> {
91+
const fn len(&self) -> usize {
92+
if N == 0 {
93+
self.len
94+
} else {
95+
N
96+
}
97+
}
98+
}
99+
100+
impl<const N: usize, const WITH_DEFAULT: bool> CaseSetter for CaseSetterN<N, WITH_DEFAULT> {
84101
#[inline]
85-
pub fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T) {
86-
small_memset::<T, UP_TO, WITH_DEFAULT>(&mut buf[self.offset..][..self.len], val);
102+
fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T) {
103+
small_memset::<_, N, WITH_DEFAULT>(&mut buf[self.offset..][..self.len()], val);
87104
}
88105

89106
/// # Safety
90107
///
91108
/// Caller must ensure that no elements of the written range are concurrently
92109
/// borrowed (immutably or mutably) at all during the call to `set_disjoint`.
93110
#[inline]
94-
pub fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
111+
fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
95112
where
96113
T: AsMutPtr<Target = V>,
97114
V: Clone + Copy,
98115
{
99-
let mut buf = buf.index_mut(self.offset..self.offset + self.len);
100-
small_memset::<V, UP_TO, WITH_DEFAULT>(&mut *buf, val);
116+
let mut buf = buf.index_mut((self.offset.., ..self.len()));
117+
small_memset::<_, N, WITH_DEFAULT>(&mut *buf, val);
101118
}
102119
}
103120

121+
/// Rank-2 polymorphic closures aren't a thing in Rust yet,
122+
/// so we need to emulate this through a generic trait with a generic method.
123+
/// Unforunately, this means we have to write the closure sugar manually.
124+
pub trait SetCtx<T> {
125+
fn call<S: CaseSetter>(self, case: &S, ctx: T) -> Self;
126+
}
127+
128+
/// Emulate a closure for a [`SetCtx`] `impl`.
129+
macro_rules! set_ctx {
130+
(
131+
// `||` is used instead of just `|` due to this bug: <https://github.com/rust-lang/rustfmt/issues/6228>.
132+
||
133+
$($lifetime:lifetime,)?
134+
$case:ident,
135+
$ctx:ident: $T:ty,
136+
// Note that the required trailing `,` is so `:expr` can precede `|`.
137+
$($up_var:ident: $up_var_ty:ty$( = $up_var_val:expr)?,)*
138+
|| $body:block
139+
) => {{
140+
use $crate::src::ctx::SetCtx;
141+
use $crate::src::ctx::CaseSetter;
142+
143+
struct F$(<$lifetime>)? {
144+
$($up_var: $up_var_ty,)*
145+
}
146+
147+
impl$(<$lifetime>)? SetCtx<$T> for F$(<$lifetime>)? {
148+
fn call<S: CaseSetter>(self, $case: &S, $ctx: $T) -> Self {
149+
let Self {
150+
$($up_var,)*
151+
} = self;
152+
$body
153+
// We destructure and re-structure `Self` so that we
154+
// can move out of refs without using `ref`/`ref mut`,
155+
// which I don't know how to match on in a macro.
156+
Self {
157+
$($up_var,)*
158+
}
159+
}
160+
}
161+
162+
F {
163+
$($up_var$(: $up_var_val)?,)*
164+
}
165+
}};
166+
}
167+
168+
pub(crate) use set_ctx;
169+
104170
/// The entrypoint to the [`CaseSet`] API.
105171
///
106172
/// `UP_TO` and `WITH_DEFAULT` are made const generic parameters rather than have multiple `case_set*` `fn`s,
@@ -117,11 +183,25 @@ impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSet<UP_TO, WITH_DEFAULT>
117183
/// The `len` and `offset` are supplied here and
118184
/// applied to each `buf` passed to [`CaseSetter::set`] in `set_ctx`.
119185
#[inline]
120-
pub fn one<T, F>(ctx: T, len: usize, offset: usize, mut set_ctx: F)
186+
pub fn one<T, F>(ctx: T, len: usize, offset: usize, set_ctx: F) -> F
121187
where
122-
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
188+
F: SetCtx<T>,
123189
{
124-
set_ctx(&CaseSetter { offset, len }, ctx);
190+
macro_rules! set_ctx {
191+
($N:literal) => {
192+
set_ctx.call(&CaseSetterN::<$N, WITH_DEFAULT> { offset, len }, ctx)
193+
};
194+
}
195+
match len {
196+
01 if UP_TO >= 01 => set_ctx!(01),
197+
02 if UP_TO >= 02 => set_ctx!(02),
198+
04 if UP_TO >= 04 => set_ctx!(04),
199+
08 if UP_TO >= 08 => set_ctx!(08),
200+
16 if UP_TO >= 16 => set_ctx!(16),
201+
32 if UP_TO >= 32 => set_ctx!(32),
202+
64 if UP_TO >= 64 => set_ctx!(64),
203+
_ => set_ctx!(0),
204+
}
125205
}
126206

127207
/// Perform many case sets in one call.
@@ -138,10 +218,10 @@ impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSet<UP_TO, WITH_DEFAULT>
138218
offsets: [usize; N],
139219
mut set_ctx: F,
140220
) where
141-
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
221+
F: SetCtx<T>,
142222
{
143223
for (dir, (len, offset)) in zip(dirs, zip(lens, offsets)) {
144-
Self::one(dir, len, offset, &mut set_ctx);
224+
set_ctx = Self::one(dir, len, offset, set_ctx);
145225
}
146226
}
147227
}

0 commit comments

Comments
 (0)