Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified shaders/rust/computenbody/particle_calculate.comp.spv
Binary file not shown.
57 changes: 40 additions & 17 deletions shaders/rust/computenbody/particle_calculate/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![cfg_attr(target_arch = "spirv", no_std)]

use spirv_std::{
glam::{vec3, vec4, Vec3, Vec4},
glam::{vec3, vec4, Vec4, Vec4Swizzles},
spirv,
arch::workgroup_memory_barrier_with_group_sync,
num_traits::Float,
};

Expand All @@ -23,40 +24,62 @@ pub struct UBO {
pub soften: f32,
}

/// Length of the workgroup-shared `Vec4` array handed to `main_cs`, and the step by
/// which its chunk loop advances the base particle index.
// NOTE(review): `main_cs` is declared with `threads(256)`, each invocation fills only
// `shared_data[local_index]`, and the inner loop reads slots `0..256`, yet the base
// index advances by this value (512). Particles in the upper half of each chunk appear
// never to be loaded — confirm this stride/workgroup-size mismatch is intended.
const SHARED_DATA_SIZE: usize = 512;

#[spirv(compute(threads(256)))]
pub fn main_cs(
#[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3,
#[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3,
#[spirv(workgroup)] shared_data: &mut [Vec4; SHARED_DATA_SIZE],
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] particles: &mut [Particle],
#[spirv(uniform, descriptor_set = 0, binding = 1)] ubo: &UBO,
) {
let index = global_id.x as usize;
let local_index = local_id.x as usize;

if index >= ubo.particle_count as usize {
return;
}

let position = vec4(particles[index].pos[0], particles[index].pos[1], particles[index].pos[2], particles[index].pos[3]);
let mut velocity = vec4(particles[index].vel[0], particles[index].vel[1], particles[index].vel[2], particles[index].vel[3]);
let mut acceleration = vec4(0.0, 0.0, 0.0, 0.0);
let mut acceleration = vec3(0.0, 0.0, 0.0);

// Calculate forces from all other particles (simplified O(N²) approach)
for i in 0..ubo.particle_count as usize {
if i == index {
continue; // Skip self-interaction
// Process particles in chunks of SHARED_DATA_SIZE
let mut i = 0u32;
while i < ubo.particle_count {
// Load particle data into shared memory
if i + (local_index as u32) < ubo.particle_count {
let particle_idx = i as usize + local_index;
shared_data[local_index] = vec4(
particles[particle_idx].pos[0],
particles[particle_idx].pos[1],
particles[particle_idx].pos[2],
particles[particle_idx].pos[3]
);
} else {
shared_data[local_index] = vec4(0.0, 0.0, 0.0, 0.0);
}

// Ensure all threads have loaded their data
unsafe {
workgroup_memory_barrier_with_group_sync();
}

let other = vec4(particles[i].pos[0], particles[i].pos[1], particles[i].pos[2], particles[i].pos[3]);
let len = vec3(other.x - position.x, other.y - position.y, other.z - position.z);
let distance_sq = len.dot(len) + ubo.soften;
let distance = distance_sq.sqrt();
let force_magnitude = ubo.gravity * other.w / distance_sq.powf(ubo.power / 2.0);
// Calculate forces from particles in shared memory
for j in 0..256 { // gl_WorkGroupSize.x = 256
let other = shared_data[j];
let len = other.xyz() - position.xyz();
let distance_sq = len.dot(len) + ubo.soften;
acceleration += ubo.gravity * len * other.w / distance_sq.powf(ubo.power * 0.5);
}

// Synchronize before next iteration
unsafe {
workgroup_memory_barrier_with_group_sync();
}

acceleration = acceleration + vec4(
len.x * force_magnitude,
len.y * force_magnitude,
len.z * force_magnitude,
0.0
);
i += SHARED_DATA_SIZE as u32;
}

// Update velocity with acceleration
Expand Down