11#![ cfg_attr( target_arch = "spirv" , no_std) ]
22
33use spirv_std:: {
4- glam:: { vec3, vec4, Vec3 , Vec4 } ,
4+ glam:: { vec3, vec4, Vec4 , Vec4Swizzles } ,
55 spirv,
6+ arch:: workgroup_memory_barrier_with_group_sync,
67 num_traits:: Float ,
78} ;
89
@@ -23,40 +24,62 @@ pub struct UBO {
2324 pub soften : f32 ,
2425}
2526
// Length of the `shared_data` workgroup array that `main_cs` receives, and the
// amount by which that kernel's outer particle loop advances `i` on each pass.
// NOTE(review): the entry point is declared with `threads(256)` and its inner
// loop reads only `shared_data[0..256]`, so a value of 512 here looks
// inconsistent with the tile that is actually loaded and consumed -- confirm.
27+ const SHARED_DATA_SIZE : usize = 512 ;
28+
2629#[ spirv( compute( threads( 256 ) ) ) ]
2730pub fn main_cs (
2831 #[ spirv( global_invocation_id) ] global_id : spirv_std:: glam:: UVec3 ,
32+ #[ spirv( local_invocation_id) ] local_id : spirv_std:: glam:: UVec3 ,
33+ #[ spirv( workgroup) ] shared_data : & mut [ Vec4 ; SHARED_DATA_SIZE ] ,
2934 #[ spirv( storage_buffer, descriptor_set = 0 , binding = 0 ) ] particles : & mut [ Particle ] ,
3035 #[ spirv( uniform, descriptor_set = 0 , binding = 1 ) ] ubo : & UBO ,
3136) {
3237 let index = global_id. x as usize ;
38+ let local_index = local_id. x as usize ;
3339
3440 if index >= ubo. particle_count as usize {
3541 return ;
3642 }
3743
3844 let position = vec4 ( particles[ index] . pos [ 0 ] , particles[ index] . pos [ 1 ] , particles[ index] . pos [ 2 ] , particles[ index] . pos [ 3 ] ) ;
3945 let mut velocity = vec4 ( particles[ index] . vel [ 0 ] , particles[ index] . vel [ 1 ] , particles[ index] . vel [ 2 ] , particles[ index] . vel [ 3 ] ) ;
40- let mut acceleration = vec4 ( 0.0 , 0.0 , 0.0 , 0.0 ) ;
46+ let mut acceleration = vec3 ( 0.0 , 0.0 , 0.0 ) ;
4147
42- // Calculate forces from all other particles (simplified O(N²) approach)
43- for i in 0 ..ubo. particle_count as usize {
44- if i == index {
45- continue ; // Skip self-interaction
48+ // Process particles in chunks of SHARED_DATA_SIZE
49+ let mut i = 0u32 ;
50+ while i < ubo. particle_count {
51+ // Load particle data into shared memory
52+ if i + ( local_index as u32 ) < ubo. particle_count {
53+ let particle_idx = i as usize + local_index;
54+ shared_data[ local_index] = vec4 (
55+ particles[ particle_idx] . pos [ 0 ] ,
56+ particles[ particle_idx] . pos [ 1 ] ,
57+ particles[ particle_idx] . pos [ 2 ] ,
58+ particles[ particle_idx] . pos [ 3 ]
59+ ) ;
60+ } else {
61+ shared_data[ local_index] = vec4 ( 0.0 , 0.0 , 0.0 , 0.0 ) ;
62+ }
63+
64+ // Ensure all threads have loaded their data
65+ unsafe {
66+ workgroup_memory_barrier_with_group_sync ( ) ;
4667 }
4768
48- let other = vec4 ( particles[ i] . pos [ 0 ] , particles[ i] . pos [ 1 ] , particles[ i] . pos [ 2 ] , particles[ i] . pos [ 3 ] ) ;
49- let len = vec3 ( other. x - position. x , other. y - position. y , other. z - position. z ) ;
50- let distance_sq = len. dot ( len) + ubo. soften ;
51- let distance = distance_sq. sqrt ( ) ;
52- let force_magnitude = ubo. gravity * other. w / distance_sq. powf ( ubo. power / 2.0 ) ;
69+ // Calculate forces from particles in shared memory
70+ for j in 0 ..256 { // gl_WorkGroupSize.x = 256
71+ let other = shared_data[ j] ;
72+ let len = other. xyz ( ) - position. xyz ( ) ;
73+ let distance_sq = len. dot ( len) + ubo. soften ;
74+ acceleration += ubo. gravity * len * other. w / distance_sq. powf ( ubo. power * 0.5 ) ;
75+ }
76+
77+ // Synchronize before next iteration
78+ unsafe {
79+ workgroup_memory_barrier_with_group_sync ( ) ;
80+ }
5381
54- acceleration = acceleration + vec4 (
55- len. x * force_magnitude,
56- len. y * force_magnitude,
57- len. z * force_magnitude,
58- 0.0
59- ) ;
82+ i += SHARED_DATA_SIZE as u32 ;
6083 }
6184
6285 // Update velocity with acceleration
0 commit comments