Skip to content

Commit 703c7d2

Browse files
Naga mesh shader SPIR-V writer (#8456)
Co-authored-by: Inner Daemons <magnus.larsson.mn@gmail.com> Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
1 parent 6dd69b0 commit 703c7d2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+2554
-5521
lines changed

CHANGELOG.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,71 @@ Bottom level categories:
4343

4444
### Major Changes
4545

46+
#### Added support for mesh shaders
47+
48+
This has been a long time coming. See [the tracking issue](https://github.com/gfx-rs/wgpu/issues/7197) for more information.
49+
They are now fully supported on Vulkan, and supported on Metal and DX12 with passthrough shaders. WGSL parsing and rewriting
50+
is supported, meaning they can be used through WESL or naga_oil.
51+
52+
Mesh shader pipelines replace standard vertex shader pipelines and allow new ways to render meshes. They form the core of
53+
some rendering engines, including Unreal Engine's Nanite. This is because they are ideal for meshlet rendering, a form
54+
of rendering where small groups of triangles are handled together, for culling and for rendering.
55+
56+
The core idea is that compute-like shaders will generate primitives directly that will then be passed to the rasterizer, rather
57+
than having a list of vertices generated individually and then using a static index buffer. This means that certain computations
58+
on nearby groups of triangles can be done together, the relationship between vertices and primitives is more programmable, and
59+
you can even pass non-interpolated per-primitive data to the fragment shader, independent of vertices.
60+
61+
Mesh shaders are very versatile, and are powerful enough to replace vertex shaders, tessellation shaders, and geometry shaders
62+
on their own or with task shaders.
63+
64+
A full example of mesh shaders in use can be seen in the `mesh_shader` example. Below is a small snippet of shader code
65+
demonstrating their usage:
66+
```wgsl
67+
@task
68+
@payload(taskPayload)
69+
@workgroup_size(1)
70+
fn ts_main() -> @builtin(mesh_task_size) vec3<u32> {
71+
// Task shaders can use workgroup variables like compute shaders
72+
workgroupData = 1.0;
73+
// Pass some data to all mesh shaders dispatched by this workgroup
74+
taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0);
75+
taskPayload.visible = 1;
76+
// Dispatch a mesh shader grid with one workgroup
77+
return vec3(1, 1, 1);
78+
}
79+
80+
@mesh(mesh_output)
81+
@payload(taskPayload)
82+
@workgroup_size(1)
83+
fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3<u32>) {
84+
// Set how many outputs this workgroup will generate
85+
mesh_output.vertex_count = 3;
86+
mesh_output.primitive_count = 1;
87+
// Can also use workgroup variables
88+
workgroupData = 2.0;
89+
90+
// Set vertex outputs
91+
mesh_output.vertices[0].position = positions[0];
92+
mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask;
93+
94+
mesh_output.vertices[1].position = positions[1];
95+
mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask;
96+
97+
mesh_output.vertices[2].position = positions[2];
98+
mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask;
99+
100+
// Set the vertex indices for the only primitive
101+
mesh_output.primitives[0].indices = vec3<u32>(0, 1, 2);
102+
// Cull it if the data passed by the task shader says to
103+
mesh_output.primitives[0].cull = taskPayload.visible != 1;
104+
// Give a noninterpolated per-primitive vec4 to the fragment shader
105+
mesh_output.primitives[0].colorMask = vec4<f32>(1.0, 0.0, 1.0, 1.0);
106+
}
107+
```
108+
109+
See other changes in this changelog for more information.
110+
46111
#### Switch from `gpu-alloc` to `gpu-allocator` in the `vulkan` backend
47112

48113
`gpu-allocator` is the allocator used in the `dx12` backend, allowing to configure
@@ -284,6 +349,7 @@ By @cwfitzgerald in [#8609](https://github.com/gfx-rs/wgpu/pull/8609).
284349
#### Vulkan
285350

286351
- Fixed a validation error regarding atomic memory semantics. By @atlv24 in [#8391](https://github.com/gfx-rs/wgpu/pull/8391).
352+
- Added mesh shader writer support, allowing WGSL shaders to be used on the Vulkan backend. Only works on NVIDIA and Intel GPUs. By @inner-daemons in [#8456](https://github.com/gfx-rs/wgpu/pull/8456).
287353

288354
#### Metal
289355
- Fixed a variety of feature detection related bugs. By @inner-daemons in [#8439](https://github.com/gfx-rs/wgpu/pull/8439).

docs/api-specs/mesh_shading.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,13 @@ struct MeshOutput {
241241
@builtin(vertex_count) vertex_count: u32,
242242
@builtin(primitive_count) primitive_count: u32,
243243
}
244+
244245
var<workgroup> mesh_output: MeshOutput;
245246
246247
@mesh(mesh_output)
247248
@payload(taskPayload)
248249
@workgroup_size(1)
249-
fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3<u32>) {
250+
fn ms_main() {
250251
mesh_output.vertex_count = 3;
251252
mesh_output.primitive_count = 1;
252253
workgroupData = 2.0;

examples/features/src/mesh_shader/mod.rs

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,9 @@
1-
use std::process::Stdio;
2-
31
// Same as in mesh shader tests
4-
fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule {
5-
let cmd = std::process::Command::new("glslc")
6-
.args([
7-
&format!(
8-
"{}/src/mesh_shader/shader.{shader_stage}",
9-
env!("CARGO_MANIFEST_DIR")
10-
),
11-
"-o",
12-
"-",
13-
"--target-env=vulkan1.2",
14-
"--target-spv=spv1.4",
15-
])
16-
.stdin(Stdio::piped())
17-
.stdout(Stdio::piped())
18-
.spawn()
19-
.expect("Failed to call glslc");
20-
let output = cmd.wait_with_output().expect("Error waiting for glslc");
21-
assert!(output.status.success());
22-
unsafe {
23-
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
24-
entry_point: "main".into(),
25-
label: None,
26-
spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)),
27-
..Default::default()
28-
})
29-
}
2+
fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule {
3+
device.create_shader_module(wgpu::ShaderModuleDescriptor {
4+
label: None,
5+
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
6+
})
307
}
318
fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::ShaderModule {
329
let out_path = format!(
@@ -83,21 +60,30 @@ impl crate::framework::Example for Example {
8360
device: &wgpu::Device,
8461
_queue: &wgpu::Queue,
8562
) -> Self {
86-
let (ts, ms, fs) = match adapter.get_info().backend {
63+
let (ts, ms, fs, ts_name, ms_name, fs_name) = match adapter.get_info().backend {
8764
wgpu::Backend::Vulkan => (
88-
compile_glsl(device, "task"),
89-
compile_glsl(device, "mesh"),
90-
compile_glsl(device, "frag"),
65+
compile_wgsl(device),
66+
compile_wgsl(device),
67+
compile_wgsl(device),
68+
"ts_main",
69+
"ms_main",
70+
"fs_main",
9171
),
9272
wgpu::Backend::Dx12 => (
9373
compile_hlsl(device, "Task", "as"),
9474
compile_hlsl(device, "Mesh", "ms"),
9575
compile_hlsl(device, "Frag", "ps"),
76+
"main",
77+
"main",
78+
"main",
9679
),
9780
wgpu::Backend::Metal => (
9881
compile_msl(device, "taskShader"),
9982
compile_msl(device, "meshShader"),
10083
compile_msl(device, "fragShader"),
84+
"main",
85+
"main",
86+
"main",
10187
),
10288
_ => panic!("Example can currently only run on vulkan, dx12 or metal"),
10389
};
@@ -111,17 +97,17 @@ impl crate::framework::Example for Example {
11197
layout: Some(&pipeline_layout),
11298
task: Some(wgpu::TaskState {
11399
module: &ts,
114-
entry_point: Some("main"),
100+
entry_point: Some(ts_name),
115101
compilation_options: Default::default(),
116102
}),
117103
mesh: wgpu::MeshState {
118104
module: &ms,
119-
entry_point: Some("main"),
105+
entry_point: Some(ms_name),
120106
compilation_options: Default::default(),
121107
},
122108
fragment: Some(wgpu::FragmentState {
123109
module: &fs,
124-
entry_point: Some("main"),
110+
entry_point: Some(fs_name),
125111
compilation_options: Default::default(),
126112
targets: &[Some(config.view_formats[0].into())],
127113
}),
@@ -208,7 +194,17 @@ pub static TEST: crate::framework::ExampleTestParams = crate::framework::Example
208194
wgpu::Features::EXPERIMENTAL_MESH_SHADER
209195
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
210196
)
211-
.limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()),
212-
comparisons: &[wgpu_test::ComparisonType::Mean(0.01)],
197+
.instance_flags(wgpu::InstanceFlags::advanced_debugging())
198+
.limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values())
199+
.skip(wgpu_test::FailureCase {
200+
backends: None,
201+
// Skip Mesa because LLVMPIPE has what is believed to be a driver bug
202+
vendor: Some(0x10005),
203+
adapter: None,
204+
driver: None,
205+
reasons: vec![],
206+
behavior: wgpu_test::FailureBehavior::Ignore,
207+
}),
208+
comparisons: &[wgpu_test::ComparisonType::Mean(0.005)],
213209
_phantom: std::marker::PhantomData::<Example>,
214210
};

examples/features/src/mesh_shader/shader.frag

Lines changed: 0 additions & 11 deletions
This file was deleted.

examples/features/src/mesh_shader/shader.mesh

Lines changed: 0 additions & 38 deletions
This file was deleted.

examples/features/src/mesh_shader/shader.task

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
enable wgpu_mesh_shader;
2+
3+
const positions = array(
4+
vec4(0., 1., 0., 1.),
5+
vec4(-1., -1., 0., 1.),
6+
vec4(1., -1., 0., 1.)
7+
);
8+
const colors = array(
9+
vec4(0., 1., 0., 1.),
10+
vec4(0., 0., 1., 1.),
11+
vec4(1., 0., 0., 1.)
12+
);
13+
14+
struct TaskPayload {
15+
colorMask: vec4<f32>,
16+
visible: bool,
17+
}
18+
struct VertexOutput {
19+
@builtin(position) position: vec4<f32>,
20+
@location(0) color: vec4<f32>,
21+
}
22+
struct PrimitiveOutput {
23+
@builtin(triangle_indices) indices: vec3<u32>,
24+
@builtin(cull_primitive) cull: bool,
25+
@per_primitive @location(1) colorMask: vec4<f32>,
26+
}
27+
struct PrimitiveInput {
28+
@per_primitive @location(1) colorMask: vec4<f32>,
29+
}
30+
31+
var<task_payload> taskPayload: TaskPayload;
32+
var<workgroup> workgroupData: f32;
33+
34+
@task
35+
@payload(taskPayload)
36+
@workgroup_size(1)
37+
fn ts_main() -> @builtin(mesh_task_size) vec3<u32> {
38+
workgroupData = 1.0;
39+
taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0);
40+
taskPayload.visible = true;
41+
return vec3(1, 1, 1);
42+
}
43+
44+
struct MeshOutput {
45+
@builtin(vertices) vertices: array<VertexOutput, 3>,
46+
@builtin(primitives) primitives: array<PrimitiveOutput, 1>,
47+
@builtin(vertex_count) vertex_count: u32,
48+
@builtin(primitive_count) primitive_count: u32,
49+
}
50+
51+
var<workgroup> mesh_output: MeshOutput;
52+
53+
@mesh(mesh_output)
54+
@payload(taskPayload)
55+
@workgroup_size(1)
56+
fn ms_main() {
57+
mesh_output.vertex_count = 3;
58+
mesh_output.primitive_count = 1;
59+
workgroupData = 2.0;
60+
61+
mesh_output.vertices[0].position = positions[0];
62+
mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask;
63+
64+
mesh_output.vertices[1].position = positions[1];
65+
mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask;
66+
67+
mesh_output.vertices[2].position = positions[2];
68+
mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask;
69+
70+
mesh_output.primitives[0].indices = vec3<u32>(0, 1, 2);
71+
mesh_output.primitives[0].cull = !taskPayload.visible;
72+
mesh_output.primitives[0].colorMask = vec4<f32>(1.0, 0.0, 1.0, 1.0);
73+
}
74+
75+
@fragment
76+
fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4<f32> {
77+
return vertex.color * primitive.colorMask;
78+
}

0 commit comments

Comments
 (0)