Skip to content

Commit

Permalink
fix: stream compaction overdraw (mosure#35)
Browse files Browse the repository at this point in the history
* feat: initial work to mitigate overdraw

* fix: stream compaction propagation

* remove duplicate code
  • Loading branch information
mosure authored Nov 21, 2023
1 parent 6c44e85 commit 1a10966
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 45 deletions.
10 changes: 5 additions & 5 deletions src/render/bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ struct DrawIndirect {
base_instance: u32,
}
struct SortingGlobal {
status_counters: array<array<atomic<u32>, #{RADIX_BASE}>, #{MAX_TILE_COUNT_C}>,
digit_histogram: array<array<atomic<u32>, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>,
assignment_counter: atomic<u32>,
}
Expand All @@ -50,7 +49,8 @@ struct Entry {

@group(3) @binding(0) var<uniform> sorting_pass_index: u32;
@group(3) @binding(1) var<storage, read_write> sorting: SortingGlobal;
@group(3) @binding(2) var<storage, read_write> draw_indirect: DrawIndirect;
@group(3) @binding(3) var<storage, read_write> input_entries: array<Entry>;
@group(3) @binding(4) var<storage, read_write> output_entries: array<Entry>;
@group(3) @binding(5) var<storage, read> sorted_entries: array<Entry>;
@group(3) @binding(2) var<storage, read_write> status_counters: array<array<atomic<u32>, #{RADIX_BASE}>>;
@group(3) @binding(3) var<storage, read_write> draw_indirect: DrawIndirect;
@group(3) @binding(4) var<storage, read_write> input_entries: array<Entry>;
@group(3) @binding(5) var<storage, read_write> output_entries: array<Entry>;
@group(3) @binding(6) var<storage, read> sorted_entries: array<Entry>;
3 changes: 2 additions & 1 deletion src/render/gaussian.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ fn vs_points(
var output: GaussianOutput;
let splat_index = sorted_entries[instance_index][1];

let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu;
let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu || splat_index == 0u;
if (discard_quad) {
output.color = vec4<f32>(0.0, 0.0, 0.0, 0.0);
output.position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
return output;
}

Expand Down
105 changes: 72 additions & 33 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,11 @@ pub struct GpuGaussianSplattingBundle {
#[derive(Debug, Clone)]
pub struct GpuGaussianCloud {
pub gaussian_buffer: Buffer,
pub count: u32,
pub count: usize,

pub draw_indirect_buffer: Buffer,
pub sorting_global_buffer: Buffer,
pub sorting_status_counter_buffer: Buffer,
pub sorting_pass_buffers: [Buffer; 4],
pub entry_buffer_a: Buffer,
pub entry_buffer_b: Buffer,
Expand All @@ -210,16 +211,22 @@ impl RenderAsset for GaussianCloud {
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});

let count = gaussian_cloud.0.len() as u32;
let count = gaussian_cloud.0.len();

// TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2)
let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("sorting global buffer"),
size: ShaderDefines::default().sorting_buffer_size as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let sorting_status_counter_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("status counters buffer"),
size: ShaderDefines::default().sorting_status_counters_buffer_size(count) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("draw indirect buffer"),
size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
Expand All @@ -241,14 +248,14 @@ impl RenderAsset for GaussianCloud {

let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer a"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer b"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
Expand All @@ -258,6 +265,7 @@ impl RenderAsset for GaussianCloud {
count,
draw_indirect_buffer,
sorting_global_buffer,
sorting_status_counter_buffer,
sorting_pass_buffers,
entry_buffer_a,
entry_buffer_b,
Expand Down Expand Up @@ -409,9 +417,20 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
};

let draw_indirect_buffer_entry = BindGroupLayoutEntry {
let sorting_status_counters_buffer_entry = BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(ShaderDefines::default().sorting_status_counters_buffer_size(1) as u64),
},
count: None,
};

let draw_indirect_buffer_entry = BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
Expand All @@ -434,9 +453,10 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
},
sorting_buffer_entry,
sorting_status_counters_buffer_entry,
draw_indirect_buffer_entry,
BindGroupLayoutEntry {
binding: 3,
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
Expand All @@ -446,7 +466,7 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
},
BindGroupLayoutEntry {
binding: 4,
binding: 5,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
Expand All @@ -462,7 +482,7 @@ impl FromWorld for GaussianCloudPipeline {
label: Some("sorted_layout"),
entries: &vec![
BindGroupLayoutEntry {
binding: 5,
binding: 6,
visibility: ShaderStages::VERTEX_FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
Expand All @@ -474,7 +494,7 @@ impl FromWorld for GaussianCloudPipeline {
],
});

let compute_layout = vec![
let sorting_layout = vec![
view_layout.clone(),
gaussian_uniform_layout.clone(),
gaussian_cloud_layout.clone(),
Expand All @@ -485,7 +505,7 @@ impl FromWorld for GaussianCloudPipeline {
let pipeline_cache = render_world.resource::<PipelineCache>();
let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_a".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -494,7 +514,7 @@ impl FromWorld for GaussianCloudPipeline {

let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_b".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -503,7 +523,7 @@ impl FromWorld for GaussianCloudPipeline {

let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_c".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -513,7 +533,7 @@ impl FromWorld for GaussianCloudPipeline {

let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flip".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: TEMPORAL_SORT_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -522,7 +542,7 @@ impl FromWorld for GaussianCloudPipeline {

let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flop".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: TEMPORAL_SORT_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand Down Expand Up @@ -560,12 +580,21 @@ struct ShaderDefines {
workgroup_invocations_c: u32,
workgroup_entries_a: u32,
workgroup_entries_c: u32,
max_tile_count_c: u32,
sorting_buffer_size: usize,
sorting_buffer_size: u32,

temporal_sort_window_size: u32,
}

impl ShaderDefines {
    /// Number of radix-sort tiles required to cover `count` entries:
    /// ceiling division of `count` by the per-workgroup entry count.
    fn max_tile_count(&self, count: usize) -> u32 {
        let entries = self.workgroup_entries_c;
        (count as u32 + entries - 1) / entries
    }

    /// Byte size of the per-tile status-counter buffer: one `u32`
    /// counter per radix digit, per tile covering `count` entries.
    fn sorting_status_counters_buffer_size(&self, count: usize) -> usize {
        let counters = self.radix_base as usize * self.max_tile_count(count) as usize;
        counters * std::mem::size_of::<u32>()
    }
}

impl Default for ShaderDefines {
fn default() -> Self {
let radix_bits_per_digit = 8;
Expand All @@ -577,10 +606,8 @@ impl Default for ShaderDefines {
let workgroup_invocations_c = radix_base;
let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a;
let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c;
let max_tile_count_c = (10000000 + workgroup_entries_c - 1) / workgroup_entries_c;
let sorting_buffer_size = radix_base as usize *
(radix_digit_places as usize + max_tile_count_c as usize) *
std::mem::size_of::<u32>() + 5 * std::mem::size_of::<u32>();
let sorting_buffer_size = radix_base * radix_digit_places *
std::mem::size_of::<u32>() as u32 + 5 * std::mem::size_of::<u32>() as u32;

Self {
radix_bits_per_digit,
Expand All @@ -592,7 +619,6 @@ impl Default for ShaderDefines {
workgroup_invocations_c,
workgroup_entries_a,
workgroup_entries_c,
max_tile_count_c,
sorting_buffer_size,

temporal_sort_window_size: 16,
Expand All @@ -615,7 +641,6 @@ fn shader_defs(
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),

ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
];
Expand Down Expand Up @@ -833,8 +858,17 @@ pub fn queue_gaussian_bind_group(
}),
};

let draw_indirect_entry = BindGroupEntry {
let sorting_status_counters_entry = BindGroupEntry {
binding: 2,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.sorting_status_counter_buffer,
offset: 0,
size: BufferSize::new(cloud.sorting_status_counter_buffer.size()),
}),
};

let draw_indirect_entry = BindGroupEntry {
binding: 3,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.draw_indirect_buffer,
offset: 0,
Expand All @@ -857,9 +891,10 @@ pub fn queue_gaussian_bind_group(
}),
},
sorting_global_entry.clone(),
sorting_status_counters_entry.clone(),
draw_indirect_entry.clone(),
BindGroupEntry {
binding: 3,
binding: 4,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.entry_buffer_a
Expand All @@ -871,7 +906,7 @@ pub fn queue_gaussian_bind_group(
}),
},
BindGroupEntry {
binding: 4,
binding: 5,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.entry_buffer_b
Expand Down Expand Up @@ -910,7 +945,7 @@ pub fn queue_gaussian_bind_group(
&gaussian_cloud_pipeline.sorted_layout,
&[
BindGroupEntry {
binding: 5,
binding: 6,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.entry_buffer_a,
offset: 0,
Expand Down Expand Up @@ -1173,6 +1208,12 @@ impl render_graph::Node for RadixSortNode {
None,
);

command_encoder.clear_buffer(
&cloud.sorting_status_counter_buffer,
0,
None,
);

command_encoder.clear_buffer(
&cloud.draw_indirect_buffer,
0,
Expand Down Expand Up @@ -1208,7 +1249,7 @@ impl render_graph::Node for RadixSortNode {
pass.set_pipeline(radix_sort_a);

let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
Expand All @@ -1219,12 +1260,10 @@ impl render_graph::Node for RadixSortNode {

for pass_idx in 0..radix_digit_places {
if pass_idx > 0 {
// clear SortingGlobal.status_counters
let size = (ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c) as u64 * std::mem::size_of::<u32>() as u64;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
&cloud.sorting_status_counter_buffer,
0,
std::num::NonZeroU64::new(size).unwrap().into()
None,
);
}

Expand Down Expand Up @@ -1255,7 +1294,7 @@ impl render_graph::Node for RadixSortNode {
);

let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1);
}
}
}
Expand Down
Loading

0 comments on commit 1a10966

Please sign in to comment.