Skip to content

Commit a240fb7

Browse files
committed
scalar or vector: add #[derive(ScalarOrVectorComposite)]
1 parent df16a21 commit a240fb7

File tree

6 files changed

+179
-0
lines changed

6 files changed

+179
-0
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
mod debug_printf;
7575
mod image;
7676
mod sample_param_permutations;
77+
mod scalar_or_vector_composite;
7778

7879
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
7980
use proc_macro::TokenStream;
@@ -311,3 +312,10 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
311312
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
312313
sample_param_permutations::gen_sample_param_permutations(item)
313314
}
315+
316+
#[proc_macro_derive(ScalarOrVectorComposite)]
317+
pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream {
318+
scalar_or_vector_composite::derive(item.into())
319+
.unwrap_or_else(syn::Error::into_compile_error)
320+
.into()
321+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use proc_macro2::TokenStream;
2+
use quote::{ToTokens, quote};
3+
use syn::punctuated::Punctuated;
4+
use syn::{Fields, FieldsNamed, FieldsUnnamed, GenericParam, Token};
5+
6+
pub fn derive(item: TokenStream) -> syn::Result<TokenStream> {
7+
// Whenever we'll properly resolve the crate symbol, replace this.
8+
let spirv_std = quote!(spirv_std);
9+
10+
// Defer all validation to our codegen backend. Rather than erroring here, emit garbage.
11+
let item = syn::parse2::<syn::ItemStruct>(item)?;
12+
let struct_ident = &item.ident;
13+
let gens = &item.generics.params;
14+
let gen_refs = &item
15+
.generics
16+
.params
17+
.iter()
18+
.map(|p| match p {
19+
GenericParam::Lifetime(p) => p.lifetime.to_token_stream(),
20+
GenericParam::Type(p) => p.ident.to_token_stream(),
21+
GenericParam::Const(p) => p.ident.to_token_stream(),
22+
})
23+
.collect::<Punctuated<_, Token![,]>>();
24+
let where_clause = &item.generics.where_clause;
25+
26+
let content =
27+
match item.fields {
28+
Fields::Named(FieldsNamed { named, .. }) => {
29+
let content = named.iter().map(|f| {
30+
let ident = &f.ident;
31+
quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f))
32+
}).collect::<Punctuated<_, Token![,]>>();
33+
quote!(Self { #content })
34+
}
35+
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
36+
let content = (0..unnamed.len())
37+
.map(|i| {
38+
let i = syn::Index::from(i);
39+
quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f))
40+
})
41+
.collect::<Punctuated<_, Token![,]>>();
42+
quote!(Self(#content))
43+
}
44+
Fields::Unit => quote!(Self),
45+
};
46+
47+
Ok(quote! {
48+
impl<#gens> #spirv_std::ScalarOrVectorComposite for #struct_ident<#gen_refs> #where_clause {
49+
#[inline]
50+
fn transform<F: #spirv_std::ScalarOrVectorTransform>(self, f: &mut F) -> Self {
51+
#content
52+
}
53+
}
54+
})
55+
}

crates/spirv-std/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
/// Public re-export of the `spirv-std-macros` crate.
8888
#[macro_use]
8989
pub extern crate spirv_std_macros as macros;
90+
pub use macros::ScalarOrVectorComposite;
9091
pub use macros::spirv;
9192
pub use macros::{debug_printf, debug_printfln};
9293

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+GroupNonUniformShuffle,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model
3+
// normalize-stderr-test "OpLine .*\n" -> ""
4+
// ignore-vulkan1.0
5+
// ignore-vulkan1.1
6+
// ignore-spv1.0
7+
// ignore-spv1.1
8+
// ignore-spv1.2
9+
// ignore-spv1.3
10+
// ignore-spv1.4
11+
12+
use glam::*;
13+
use spirv_std::ScalarOrVectorComposite;
14+
use spirv_std::arch::*;
15+
use spirv_std::spirv;
16+
17+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
18+
pub struct MyStruct {
19+
a: f32,
20+
b: UVec3,
21+
c: Nested,
22+
d: Zst,
23+
}
24+
25+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
26+
pub struct Nested(i32);
27+
28+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
29+
pub struct Zst;
30+
31+
#[spirv(compute(threads(32)))]
32+
pub fn main(
33+
#[spirv(local_invocation_index)] inv_id: UVec3,
34+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3,
35+
) {
36+
unsafe {
37+
let my_struct = MyStruct {
38+
a: 1.,
39+
b: inv_id,
40+
c: Nested(-42),
41+
d: Zst,
42+
};
43+
44+
let mut out = UVec3::ZERO;
45+
// before spv1.5 / vulkan1.2, this id = 19 must be a constant
46+
out += subgroup_broadcast(my_struct, 19).b;
47+
out += subgroup_broadcast_first(my_struct).b;
48+
out += subgroup_shuffle(my_struct, 2).b;
49+
out += subgroup_shuffle_xor(my_struct, 4).b;
50+
out += subgroup_shuffle_up(my_struct, 5).b;
51+
out += subgroup_shuffle_down(my_struct, 7).b;
52+
*output = out;
53+
}
54+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_shuffle::disassembly
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
6+
use glam::*;
7+
use spirv_std::ScalarOrVectorComposite;
8+
use spirv_std::arch::*;
9+
use spirv_std::spirv;
10+
11+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
12+
pub struct MyStruct {
13+
a: f32,
14+
b: UVec3,
15+
c: Nested,
16+
d: Zst,
17+
}
18+
19+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
20+
pub struct Nested(i32);
21+
22+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
23+
pub struct Zst;
24+
25+
/// this should be 3 `subgroup_shuffle` instructions, with all calls inlined
26+
fn disassembly(my_struct: MyStruct, id: u32) -> MyStruct {
27+
subgroup_shuffle(my_struct, id)
28+
}
29+
30+
#[spirv(compute(threads(32)))]
31+
pub fn main(
32+
#[spirv(local_invocation_index)] inv_id: UVec3,
33+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyStruct,
34+
) {
35+
unsafe {
36+
let my_struct = MyStruct {
37+
a: inv_id.x as f32,
38+
b: inv_id,
39+
c: Nested(5i32 - inv_id.x as i32),
40+
d: Zst,
41+
};
42+
43+
*output = disassembly(my_struct, 5);
44+
}
45+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %2
3+
%5 = OpFunctionParameter %6
4+
%7 = OpLabel
5+
%9 = OpCompositeExtract %10 %4 0
6+
%12 = OpGroupNonUniformShuffle %10 %13 %9 %5
7+
%14 = OpCompositeExtract %15 %4 1
8+
%16 = OpGroupNonUniformShuffle %15 %13 %14 %5
9+
%17 = OpCompositeExtract %18 %4 2
10+
%19 = OpGroupNonUniformShuffle %18 %13 %17 %5
11+
%20 = OpCompositeInsert %2 %12 %21 0
12+
%22 = OpCompositeInsert %2 %16 %20 1
13+
%23 = OpCompositeInsert %2 %19 %22 2
14+
OpNoLine
15+
OpReturnValue %23
16+
OpFunctionEnd

0 commit comments

Comments
 (0)