diff --git a/src/render/bindings.wgsl b/src/render/bindings.wgsl
index a6c242ef..33690a4d 100644
--- a/src/render/bindings.wgsl
+++ b/src/render/bindings.wgsl
@@ -31,7 +31,6 @@ struct DrawIndirect {
     base_instance: u32,
 }
 struct SortingGlobal {
-    status_counters: array<array<atomic<u32>, #{RADIX_BASE}>, #{MAX_TILE_COUNT_C}>,
     digit_histogram: array<array<atomic<u32>, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>,
     assignment_counter: atomic<u32>,
 }
@@ -50,7 +49,8 @@ struct Entry {
 
 @group(3) @binding(0) var<uniform> sorting_pass_index: u32;
 @group(3) @binding(1) var<storage, read_write> sorting: SortingGlobal;
-@group(3) @binding(2) var<storage, read_write> draw_indirect: DrawIndirect;
-@group(3) @binding(3) var<storage, read_write> input_entries: array<Entry>;
-@group(3) @binding(4) var<storage, read_write> output_entries: array<Entry>;
-@group(3) @binding(5) var<storage, read_write> sorted_entries: array<Entry>;
+@group(3) @binding(2) var<storage, read_write> status_counters: array<array<atomic<u32>, #{RADIX_BASE}>>;
+@group(3) @binding(3) var<storage, read_write> draw_indirect: DrawIndirect;
+@group(3) @binding(4) var<storage, read_write> input_entries: array<Entry>;
+@group(3) @binding(5) var<storage, read_write> output_entries: array<Entry>;
+@group(3) @binding(6) var<storage, read_write> sorted_entries: array<Entry>;
diff --git a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl
index 1c28873d..b71e1c97 100644
--- a/src/render/gaussian.wgsl
+++ b/src/render/gaussian.wgsl
@@ -200,9 +200,10 @@ fn vs_points(
     var output: GaussianOutput;
 
     let splat_index = sorted_entries[instance_index][1];
-    let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu;
+    let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu || splat_index == 0u;
     if (discard_quad) {
         output.color = vec4(0.0, 0.0, 0.0, 0.0);
+        output.position = vec4(0.0, 0.0, 0.0, 0.0);
         return output;
     }
 
diff --git a/src/render/mod.rs b/src/render/mod.rs
index ecf675b0..2b93167a 100644
--- a/src/render/mod.rs
+++ b/src/render/mod.rs
@@ -183,10 +183,11 @@ pub struct GpuGaussianSplattingBundle {
 #[derive(Debug, Clone)]
 pub struct GpuGaussianCloud {
     pub gaussian_buffer: Buffer,
-    pub count: u32,
+    pub count: usize,
 
     pub draw_indirect_buffer: Buffer,
     pub sorting_global_buffer: Buffer,
+    pub sorting_status_counter_buffer: Buffer,
     pub sorting_pass_buffers: [Buffer; 4],
     pub entry_buffer_a: Buffer,
     pub entry_buffer_b: Buffer,
@@ -210,9 +211,8 @@ impl RenderAsset for GaussianCloud {
             usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
         });
 
-        let count = gaussian_cloud.0.len() as u32;
+        let count = gaussian_cloud.0.len();
 
-        // TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2)
         let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
             label: Some("sorting global buffer"),
             size: ShaderDefines::default().sorting_buffer_size as u64,
@@ -220,6 +220,13 @@ impl RenderAsset for GaussianCloud {
             mapped_at_creation: false,
         });
 
+        let sorting_status_counter_buffer = render_device.create_buffer(&BufferDescriptor {
+            label: Some("status counters buffer"),
+            size: ShaderDefines::default().sorting_status_counters_buffer_size(count) as u64,
+            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        });
+
         let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
             label: Some("draw indirect buffer"),
             size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
@@ -241,14 +248,14 @@ impl RenderAsset for GaussianCloud {
 
         let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
             label: Some("entry buffer a"),
-            size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
+            size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
             usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
             mapped_at_creation: false,
         });
 
         let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
             label: Some("entry buffer b"),
-            size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
+            size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
             usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
             mapped_at_creation: false,
         });
@@ -258,6 +265,7 @@ impl RenderAsset for GaussianCloud {
             count,
             draw_indirect_buffer,
             sorting_global_buffer,
+            sorting_status_counter_buffer,
             sorting_pass_buffers,
             entry_buffer_a,
             entry_buffer_b,
@@ -409,9 +417,20 @@ impl FromWorld for GaussianCloudPipeline {
             count: None,
         };
 
-        let draw_indirect_buffer_entry = BindGroupLayoutEntry {
+        let sorting_status_counters_buffer_entry = BindGroupLayoutEntry {
             binding: 2,
             visibility: ShaderStages::COMPUTE,
+            ty: BindingType::Buffer {
+                ty: BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: BufferSize::new(ShaderDefines::default().sorting_status_counters_buffer_size(1) as u64),
+            },
+            count: None,
+        };
+
+        let draw_indirect_buffer_entry = BindGroupLayoutEntry {
+            binding: 3,
+            visibility: ShaderStages::COMPUTE,
             ty: BindingType::Buffer {
                 ty: BufferBindingType::Storage { read_only: false },
                 has_dynamic_offset: false,
@@ -434,9 +453,10 @@ impl FromWorld for GaussianCloudPipeline {
                     count: None,
                 },
                 sorting_buffer_entry,
+                sorting_status_counters_buffer_entry,
                 draw_indirect_buffer_entry,
                 BindGroupLayoutEntry {
-                    binding: 3,
+                    binding: 4,
                     visibility: ShaderStages::COMPUTE,
                     ty: BindingType::Buffer {
                         ty: BufferBindingType::Storage { read_only: false },
@@ -446,7 +466,7 @@ impl FromWorld for GaussianCloudPipeline {
                     count: None,
                 },
                 BindGroupLayoutEntry {
-                    binding: 4,
+                    binding: 5,
                     visibility: ShaderStages::COMPUTE,
                     ty: BindingType::Buffer {
                         ty: BufferBindingType::Storage { read_only: false },
@@ -462,7 +482,7 @@ impl FromWorld for GaussianCloudPipeline {
             label: Some("sorted_layout"),
             entries: &vec![
                 BindGroupLayoutEntry {
-                    binding: 5,
+                    binding: 6,
                     visibility: ShaderStages::VERTEX_FRAGMENT,
                     ty: BindingType::Buffer {
                         ty: BufferBindingType::Storage { read_only: true },
@@ -474,7 +494,7 @@ impl FromWorld for GaussianCloudPipeline {
             ],
         });
 
-        let compute_layout = vec![
+        let sorting_layout = vec![
             view_layout.clone(),
             gaussian_uniform_layout.clone(),
             gaussian_cloud_layout.clone(),
@@ -485,7 +505,7 @@ impl FromWorld for GaussianCloudPipeline {
         let pipeline_cache = render_world.resource::<PipelineCache>();
         let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
             label: Some("radix_sort_a".into()),
-            layout: compute_layout.clone(),
+            layout: sorting_layout.clone(),
             push_constant_ranges: vec![],
             shader: RADIX_SHADER_HANDLE,
             shader_defs: shader_defs.clone(),
@@ -494,7 +514,7 @@ impl FromWorld for GaussianCloudPipeline {
 
         let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
             label: Some("radix_sort_b".into()),
-            layout: compute_layout.clone(),
+            layout: sorting_layout.clone(),
             push_constant_ranges: vec![],
             shader: RADIX_SHADER_HANDLE,
             shader_defs: shader_defs.clone(),
@@ -503,7 +523,7 @@ impl FromWorld for GaussianCloudPipeline {
 
         let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
             label: Some("radix_sort_c".into()),
-            layout: compute_layout.clone(),
+            layout: sorting_layout.clone(),
             push_constant_ranges: vec![],
             shader: RADIX_SHADER_HANDLE,
             shader_defs: shader_defs.clone(),
@@ -513,7 +533,7 @@ impl FromWorld for GaussianCloudPipeline {
 
         let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
             label: Some("temporal_sort_flip".into()),
-            layout: compute_layout.clone(),
+            layout: sorting_layout.clone(),
             push_constant_ranges: vec![],
             shader: TEMPORAL_SORT_SHADER_HANDLE,
             shader_defs: shader_defs.clone(),
@@ -522,7 +542,7 @@ impl FromWorld for GaussianCloudPipeline {
 
         let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
             label: Some("temporal_sort_flop".into()),
-            layout: compute_layout.clone(),
+            layout: sorting_layout.clone(),
             push_constant_ranges: vec![],
             shader: TEMPORAL_SORT_SHADER_HANDLE,
             shader_defs: shader_defs.clone(),
@@ -560,12 +580,21 @@ struct ShaderDefines {
     workgroup_invocations_c: u32,
     workgroup_entries_a: u32,
     workgroup_entries_c: u32,
-    max_tile_count_c: u32,
-    sorting_buffer_size: usize,
+    sorting_buffer_size: u32,
 
     temporal_sort_window_size: u32,
 }
 
+impl ShaderDefines {
+    fn max_tile_count(&self, count: usize) -> u32 {
+        (count as u32 + self.workgroup_entries_c - 1) / self.workgroup_entries_c
+    }
+
+    fn sorting_status_counters_buffer_size(&self, count: usize) -> usize {
+        self.radix_base as usize * self.max_tile_count(count) as usize * std::mem::size_of::<u32>()
+    }
+}
+
 impl Default for ShaderDefines {
     fn default() -> Self {
         let radix_bits_per_digit = 8;
@@ -577,10 +606,8 @@ impl Default for ShaderDefines {
         let workgroup_invocations_c = radix_base;
         let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a;
         let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c;
-        let max_tile_count_c = (10000000 + workgroup_entries_c - 1) / workgroup_entries_c;
-        let sorting_buffer_size = radix_base as usize *
-            (radix_digit_places as usize + max_tile_count_c as usize) *
-            std::mem::size_of::<u32>() + 5 * std::mem::size_of::<u32>();
+        let sorting_buffer_size = radix_base * radix_digit_places *
+            std::mem::size_of::<u32>() as u32 + 5 * std::mem::size_of::<u32>() as u32;
 
         Self {
             radix_bits_per_digit,
@@ -592,7 +619,6 @@ impl Default for ShaderDefines {
             workgroup_invocations_c,
             workgroup_entries_a,
             workgroup_entries_c,
-            max_tile_count_c,
             sorting_buffer_size,
 
             temporal_sort_window_size: 16,
@@ -615,7 +641,6 @@ fn shader_defs(
         ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
         ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
         ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
-        ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),
         ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
     ];
 
@@ -833,8 +858,17 @@ pub fn queue_gaussian_bind_group(
             }),
         };
 
-        let draw_indirect_entry = BindGroupEntry {
+        let sorting_status_counters_entry = BindGroupEntry {
             binding: 2,
+            resource: BindingResource::Buffer(BufferBinding {
+                buffer: &cloud.sorting_status_counter_buffer,
+                offset: 0,
+                size: BufferSize::new(cloud.sorting_status_counter_buffer.size()),
+            }),
+        };
+
+        let draw_indirect_entry = BindGroupEntry {
+            binding: 3,
             resource: BindingResource::Buffer(BufferBinding {
                 buffer: &cloud.draw_indirect_buffer,
                 offset: 0,
@@ -857,9 +891,10 @@ pub fn queue_gaussian_bind_group(
                     }),
                 },
                 sorting_global_entry.clone(),
+                sorting_status_counters_entry.clone(),
                 draw_indirect_entry.clone(),
                 BindGroupEntry {
-                    binding: 3,
+                    binding: 4,
                     resource: BindingResource::Buffer(BufferBinding {
                         buffer: if idx % 2 == 0 {
                             &cloud.entry_buffer_a
@@ -871,7 +906,7 @@ pub fn queue_gaussian_bind_group(
                     }),
                 },
                 BindGroupEntry {
-                    binding: 4,
+                    binding: 5,
                     resource: BindingResource::Buffer(BufferBinding {
                         buffer: if idx % 2 == 0 {
                             &cloud.entry_buffer_b
@@ -910,7 +945,7 @@ pub fn queue_gaussian_bind_group(
             &gaussian_cloud_pipeline.sorted_layout,
             &[
                 BindGroupEntry {
-                    binding: 5,
+                    binding: 6,
                     resource: BindingResource::Buffer(BufferBinding {
                         buffer: &cloud.entry_buffer_a,
                         offset: 0,
@@ -1173,6 +1208,12 @@ impl render_graph::Node for RadixSortNode {
                 None,
             );
 
+            command_encoder.clear_buffer(
+                &cloud.sorting_status_counter_buffer,
+                0,
+                None,
+            );
+
             command_encoder.clear_buffer(
                 &cloud.draw_indirect_buffer,
                 0,
@@ -1208,7 +1249,7 @@ impl render_graph::Node for RadixSortNode {
                 pass.set_pipeline(radix_sort_a);
 
                 let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
-                pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
+                pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
             }
 
             let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
@@ -1219,12 +1260,10 @@ impl render_graph::Node for RadixSortNode {
 
             for pass_idx in 0..radix_digit_places {
                 if pass_idx > 0 {
-                    // clear SortingGlobal.status_counters
-                    let size = (ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c) as u64 * std::mem::size_of::<u32>() as u64;
                     command_encoder.clear_buffer(
-                        &cloud.sorting_global_buffer,
+                        &cloud.sorting_status_counter_buffer,
                         0,
-                        std::num::NonZeroU64::new(size).unwrap().into()
+                        None,
                     );
                 }
 
@@ -1255,7 +1294,7 @@ impl render_graph::Node for RadixSortNode {
                 );
 
                 let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
-                pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
+                pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1);
             }
         }
     }
diff --git a/src/render/sort/radix.wgsl b/src/render/sort/radix.wgsl
index c867a583..9048dfbf 100644
--- a/src/render/sort/radix.wgsl
+++ b/src/render/sort/radix.wgsl
@@ -5,6 +5,7 @@
     points,
     sorting_pass_index,
     sorting,
+    status_counters,
     draw_indirect,
     input_entries,
     output_entries,
@@ -153,7 +154,7 @@ fn radix_sort_c(
    var keys: array<u32, #{ENTRIES_PER_INVOCATION_C}>;
    var ranks: array<u32, #{ENTRIES_PER_INVOCATION_C}>;
    for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) {
-        keys[entry_index] = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][0];
+        keys[entry_index] = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].key;
         let digit = (keys[entry_index] >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u);
         // TODO: Implement warp-level multi-split (WLMS) once WebGPU supports subgroup operations
         ranks[entry_index] = atomicAdd(&sorting_shared_c.scan[digit + conflict_free_offset(digit)], 1u);
@@ -166,7 +167,7 @@ fn radix_sort_c(
     sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = local_digit_offset;
 
     // Chained decoupling lookback
-    atomicStore(&sorting.status_counters[assignment][gl_LocalInvocationID.x], 0x40000000u | local_digit_count);
+    atomicStore(&status_counters[assignment][gl_LocalInvocationID.x], 0x40000000u | local_digit_count);
     var global_digit_count = 0u;
     var previous_tile = assignment;
     while true {
@@ -177,14 +178,14 @@ fn radix_sort_c(
         previous_tile -= 1u;
         var status_counter = 0u;
         while((status_counter & 0xC0000000u) == 0u) {
-            status_counter = atomicLoad(&sorting.status_counters[previous_tile][gl_LocalInvocationID.x]);
+            status_counter = atomicLoad(&status_counters[previous_tile][gl_LocalInvocationID.x]);
         }
         global_digit_count += status_counter & 0x3FFFFFFFu;
         if((status_counter & 0x80000000u) != 0u) {
             break;
         }
     }
-    atomicStore(&sorting.status_counters[assignment][gl_LocalInvocationID.x], 0x80000000u | (global_digit_count + local_digit_count));
+    atomicStore(&status_counters[assignment][gl_LocalInvocationID.x], 0x80000000u | (global_digit_count + local_digit_count));
     if(sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && gl_LocalInvocationID.x == #{WORKGROUP_INVOCATIONS_C}u - 2u && global_entry_offset + #{WORKGROUP_ENTRIES_C}u >= arrayLength(&points)) {
         draw_indirect.vertex_count = 4u;
         draw_indirect.instance_count = global_digit_count + local_digit_count;
@@ -208,13 +209,13 @@ fn radix_sort_c(
         let key = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x];
         let digit = (key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u);
         keys[entry_index] = digit;
-        output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][0] = key;
+        output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].key = key;
     }
     workgroupBarrier();
 
     // Load values from global memory and scatter them inside shared memory
     for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) {
-        let value = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][1];
+        let value = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].value;
         sorting_shared_c.entries[ranks[entry_index]] = value;
     }
     workgroupBarrier();
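
Reviewer note: a quick, standalone sanity check of the sizing math this diff adds to `ShaderDefines`. `max_tile_count` is a ceiling division of the entry count by `workgroup_entries_c`, and the status-counter buffer holds `radix_base` `u32` words per tile. The `ENTRIES_PER_INVOCATION_C = 4` below is an assumption made for the example; the real value comes from `ShaderDefines::default()`, which this diff does not show.

```rust
// Illustrative sketch only; mirrors ShaderDefines::max_tile_count and
// ShaderDefines::sorting_status_counters_buffer_size from this diff.
fn max_tile_count(count: usize, workgroup_entries_c: u32) -> u32 {
    // Ceiling division: each radix_sort_c tile covers workgroup_entries_c entries.
    (count as u32 + workgroup_entries_c - 1) / workgroup_entries_c
}

fn sorting_status_counters_buffer_size(count: usize, radix_base: u32, workgroup_entries_c: u32) -> usize {
    // One u32 status counter per digit per tile.
    radix_base as usize
        * max_tile_count(count, workgroup_entries_c) as usize
        * std::mem::size_of::<u32>()
}

fn main() {
    let radix_base = 1u32 << 8;               // RADIX_BASE = 256
    let workgroup_entries_c = radix_base * 4; // assumed ENTRIES_PER_INVOCATION_C = 4

    // A 1_000_000-splat cloud needs ceil(1_000_000 / 1024) = 977 tiles,
    // i.e. 977 * 256 counters * 4 bytes ~= 1.0 MB for this cloud, instead of
    // an allocation that was always sized for 10_000_000 entries.
    assert_eq!(max_tile_count(1_000_000, workgroup_entries_c), 977);
    assert_eq!(sorting_status_counters_buffer_size(1_000_000, radix_base, workgroup_entries_c), 1_000_448);
}
```

This is the payoff of splitting `status_counters` out of `SortingGlobal`: `sorting_global_buffer` (digit histogram plus counters) stays fixed-size, while the tile-dependent portion now scales with each cloud.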
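
On the `radix.wgsl` side, the words written into the new `status_counters` buffer keep the same encoding as before: the top two bits are a publication tag and the low 30 bits carry a digit count. A minimal sketch of that layout (the constants are taken from the shader; the names and the `decode` helper are illustrative, not from the code):

```rust
// Status word layout used by the chained decoupled lookback in radix_sort_c.
const AGGREGATE_READY: u32 = 0x4000_0000; // stored as 0x40000000u | local_digit_count
const PREFIX_READY: u32 = 0x8000_0000;    // stored as 0x80000000u | inclusive prefix
const COUNT_MASK: u32 = 0x3FFF_FFFF;      // low 30 bits hold the count itself

// Mirrors the shader's spin condition `(status_counter & 0xC0000000u) == 0u`:
// a word with neither flag set has not been published yet.
fn decode(status: u32) -> Option<(bool, u32)> {
    if status & (AGGREGATE_READY | PREFIX_READY) == 0 {
        return None; // the lookback loop keeps spinning on this tile
    }
    Some((status & PREFIX_READY != 0, status & COUNT_MASK))
}

fn main() {
    // A tile first publishes its local digit count...
    assert_eq!(decode(AGGREGATE_READY | 42), Some((false, 42)));
    // ...then its inclusive prefix, which lets later tiles stop walking back.
    assert_eq!(decode(PREFIX_READY | 1042), Some((true, 1042)));
    assert_eq!(decode(0), None);
}
```

This also motivates the added `clear_buffer` calls on `sorting_status_counter_buffer`: a stale flag bit left over from an earlier digit place would read as a published count, so the buffer is zeroed up front and between passes.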