Skip to content

Commit 9fd3c9e

Browse files
authored
Add an IndexUnchecked trait that uses asm! (#805)
* Add an IndexUnchecked trait to spirv-std/arch * Slap some #[spirv_std_macros::gpu_only] on there * Spelling * Add safety sections to the docs * Improve documentation, implement for non-spirv targets
1 parent 6232d95 commit 9fd3c9e

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

crates/spirv-std/src/arch.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,92 @@ pub fn signed_min<T: SignedInteger>(a: T, b: T) -> T {
245245
pub fn signed_max<T: SignedInteger>(a: T, b: T) -> T {
246246
unsafe { call_glsl_op_with_ints::<_, 42>(a, b) }
247247
}
248+
249+
/// Index into an array without bounds checking.
250+
///
251+
/// The main purpose of this trait is to work around the fact that the regular `get_unchecked*`
252+
/// methods do not work in in SPIR-V.
253+
pub trait IndexUnchecked<T> {
254+
/// Returns a reference to the element at `index`. The equivalent of `get_unchecked`.
255+
///
256+
/// # Safety
257+
/// Behavior is undefined if the `index` value is greater than or equal to the length of the array.
258+
unsafe fn index_unchecked(&self, index: usize) -> &T;
259+
/// Returns a mutable reference to the element at `index`. The equivalent of `get_unchecked_mut`.
260+
///
261+
/// # Safety
262+
/// Behavior is undefined if the `index` value is greater than or equal to the length of the array.
263+
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T;
264+
}
265+
266+
impl<T> IndexUnchecked<T> for [T] {
267+
#[cfg(target_arch = "spirv")]
268+
unsafe fn index_unchecked(&self, index: usize) -> &T {
269+
asm!(
270+
"%slice_ptr = OpLoad _ {slice_ptr_ptr}",
271+
"%data_ptr = OpCompositeExtract _ %slice_ptr 0",
272+
"%val_ptr = OpAccessChain _ %data_ptr {index}",
273+
"OpReturnValue %val_ptr",
274+
slice_ptr_ptr = in(reg) &self,
275+
index = in(reg) index,
276+
options(noreturn)
277+
)
278+
}
279+
280+
#[cfg(not(target_arch = "spirv"))]
281+
unsafe fn index_unchecked(&self, index: usize) -> &T {
282+
self.get_unchecked(index)
283+
}
284+
285+
#[cfg(target_arch = "spirv")]
286+
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
287+
asm!(
288+
"%slice_ptr = OpLoad _ {slice_ptr_ptr}",
289+
"%data_ptr = OpCompositeExtract _ %slice_ptr 0",
290+
"%val_ptr = OpAccessChain _ %data_ptr {index}",
291+
"OpReturnValue %val_ptr",
292+
slice_ptr_ptr = in(reg) &self,
293+
index = in(reg) index,
294+
options(noreturn)
295+
)
296+
}
297+
298+
#[cfg(not(target_arch = "spirv"))]
299+
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
300+
self.get_unchecked_mut(index)
301+
}
302+
}
303+
304+
impl<T, const N: usize> IndexUnchecked<T> for [T; N] {
305+
#[cfg(target_arch = "spirv")]
306+
unsafe fn index_unchecked(&self, index: usize) -> &T {
307+
asm!(
308+
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
309+
"OpReturnValue %val_ptr",
310+
array_ptr = in(reg) self,
311+
index = in(reg) index,
312+
options(noreturn)
313+
)
314+
}
315+
316+
#[cfg(not(target_arch = "spirv"))]
317+
unsafe fn index_unchecked(&self, index: usize) -> &T {
318+
self.get_unchecked(index)
319+
}
320+
321+
#[cfg(target_arch = "spirv")]
322+
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
323+
asm!(
324+
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
325+
"OpReturnValue %val_ptr",
326+
array_ptr = in(reg) self,
327+
index = in(reg) index,
328+
options(noreturn)
329+
)
330+
}
331+
332+
#[cfg(not(target_arch = "spirv"))]
333+
unsafe fn index_unchecked_mut(&mut self, index: usize) -> &mut T {
334+
self.get_unchecked_mut(index)
335+
}
336+
}

tests/ui/arch/index_unchecked.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
3+
use spirv_std::arch::IndexUnchecked;
4+
5+
#[spirv(fragment)]
6+
pub fn main(
7+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] runtime_array: &mut [u32],
8+
#[spirv(descriptor_set = 1, binding = 1, storage_buffer)] array: &mut [u32; 5],
9+
) {
10+
unsafe {
11+
*runtime_array.index_unchecked_mut(0) = *array.index_unchecked(0);
12+
*array.index_unchecked_mut(1) = *runtime_array.index_unchecked(1);
13+
}
14+
}

0 commit comments

Comments
 (0)