Skip to content

Commit 7f9332f

Browse files
committed
WIP: User imports
1 parent b946a43 commit 7f9332f

File tree

24 files changed

+213
-69
lines changed

24 files changed

+213
-69
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use crate::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion;
2+
3+
pub fn merge_libraries_into_wgsl_module(user_module: &mut WgslShaderModuleUserPortion, library_modules: &mut Vec<WgslShaderModuleUserPortion>) {
4+
for library in library_modules.iter_mut() {
5+
user_module.helper_functions.append(&mut library.helper_functions);
6+
user_module.static_consts.append(&mut library.static_consts);
7+
user_module.helper_types.append(&mut library.helper_types);
8+
}
9+
}

bevy_gpu_compute_core/src/rust/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod in_out_metadata;
22
mod iter_space_dimmensions;
3+
mod library_import;
34
mod max_output_lengths;
45
mod type_erased_array_input_data;
56
mod type_erased_config_input_data;
@@ -8,6 +9,7 @@ mod type_safe_api_helpers;
89

910
pub use in_out_metadata::*;
1011
pub use iter_space_dimmensions::*;
12+
pub use library_import::*;
1113
pub use max_output_lengths::*;
1214
pub use type_erased_array_input_data::*;
1315
pub use type_erased_config_input_data::*;

bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ mod tests {
9393

9494
#[test]
9595
fn test_wgsl_shader_module_library_portion_from_user_portion() {
96-
let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3<u32>)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)]))
96+
let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3<u32>)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)])), use_statements: vec![],
9797
};
9898

9999
let expected_wgsl_code = "const example_module_const : u32 = 42;

bevy_gpu_compute_core/src/wgsl/shader_module/user_defined_portion.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub struct WgslShaderModuleUserPortion {
2626
/// look for any attempt to ASSIGN to the value of "global_id.x", "global_id.y", or "global_id.z" or just "global_id" and throw an error
2727
pub main_function: Option<WgslFunction>,
2828
pub binding_numbers_by_variable_name: Option<HashMap<String, u32>>,
29+
pub use_statements: Vec<WgslImport>,
2930
}
3031
impl WgslShaderModuleUserPortion {
3132
pub fn empty() -> Self {
@@ -38,6 +39,7 @@ impl WgslShaderModuleUserPortion {
3839
helper_functions: vec![],
3940
main_function: None,
4041
binding_numbers_by_variable_name: None,
42+
use_statements: vec![],
4143
}
4244
}
4345
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[derive(Clone, Debug, PartialEq)]
2+
pub struct WgslImport {
3+
pub path: String,
4+
}

bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod code;
22
mod const_assignment;
33
mod custom_type;
44
mod function;
5+
mod import;
56
mod input_array;
67
mod output_array;
78
mod wgpu_binding;
@@ -11,6 +12,7 @@ pub use code::*;
1112
pub use const_assignment::*;
1213
pub use custom_type::*;
1314
pub use function::*;
15+
pub use import::*;
1416
pub use input_array::*;
1517
pub use output_array::*;
1618
pub use wgpu_binding::*;

bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use crate::pipeline::phases::custom_type_collector::custom_type::CustomType;
1+
use crate::pipeline::phases::{
2+
custom_type_collector::custom_type::CustomType, user_import_collector::user_import::UserImport,
3+
};
24
use bevy_gpu_compute_core::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion;
35
use proc_macro2::TokenStream;
46

57
pub struct CompilationMetadata {
8+
pub user_imports: Option<Vec<UserImport>>,
69
pub main_func_required: bool,
710
pub custom_types: Option<Vec<CustomType>>,
811
pub wgsl_module_user_portion: Option<WgslShaderModuleUserPortion>,

bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ use proc_macro2::TokenStream;
33

44
use super::{
55
compilation_metadata::CompilationMetadata,
6-
phases::custom_type_collector::custom_type::CustomType,
6+
phases::{
7+
custom_type_collector::custom_type::CustomType,
8+
user_import_collector::user_import::UserImport,
9+
},
710
};
811

912
pub struct CompilationUnit {
@@ -22,6 +25,7 @@ impl CompilationUnit {
2225
rust_module_for_gpu: None,
2326
compiled_tokens: None,
2427
metadata: CompilationMetadata {
28+
user_imports: None,
2529
custom_types: None,
2630
wgsl_module_user_portion: None,
2731
typesafe_buffer_builders: None,
@@ -44,6 +48,15 @@ impl CompilationUnit {
4448
}
4549
self.rust_module_for_cpu.as_ref().unwrap()
4650
}
51+
pub fn set_user_imports(&mut self, user_imports: Vec<UserImport>) {
52+
self.metadata.user_imports = Some(user_imports);
53+
}
54+
pub fn user_imports(&self) -> &Vec<UserImport> {
55+
if self.metadata.user_imports.is_none() {
56+
panic!("user_imports is not set");
57+
}
58+
self.metadata.user_imports.as_ref().unwrap()
59+
}
4760
pub fn set_custom_types(&mut self, custom_types: Vec<CustomType>) {
4861
self.metadata.custom_types = Some(custom_types);
4962
}

bevy_gpu_compute_macro/src/pipeline/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::phases::{
77
module_for_rust_usage_cleaner::compiler_phase::ModuleForRustUsageCleaner,
88
non_mutating_tree_validation::compiler_phase::NonMutatingTreeValidation,
99
typesafe_buffer_builders_generator::compiler_phase::TypesafeBufferBuildersGenerator,
10+
user_import_collector::compiler_phase::UserImportCollector,
1011
wgsl_helper_transformer::compiler_phase::WgslHelperTransformer,
1112
};
1213
use crate::pipeline::compilation_unit::CompilationUnit;
@@ -20,6 +21,7 @@ impl Default for CompilerPipeline {
2021
Self {
2122
phases: vec![
2223
Box::new(NonMutatingTreeValidation {}),
24+
Box::new(UserImportCollector {}),
2325
Box::new(CustomTypeCollector {}),
2426
Box::new(TypesafeBufferBuildersGenerator {}),
2527
Box::new(WgslHelperTransformer {}),

bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/per_component_expansion.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use std::collections::HashMap;
22

3-
use bevy_gpu_compute_core::{
4-
wgsl::shader_custom_type_name::ShaderCustomTypeName,
5-
wgsl::shader_sections::{
6-
WgslConstAssignment, WgslFunction, WgslInputArray, WgslOutputArray,
3+
use bevy_gpu_compute_core::wgsl::{
4+
shader_custom_type_name::ShaderCustomTypeName,
5+
shader_sections::{
6+
WgslConstAssignment, WgslFunction, WgslImport, WgslInputArray, WgslOutputArray,
77
WgslShaderModuleSectionCode, WgslType,
88
},
99
};
@@ -88,6 +88,12 @@ impl ToStructInitializer {
8888
}
8989
)
9090
}
91+
92+
pub fn wgsl_import(c: &WgslImport) -> TokenStream {
93+
let i: TokenStream = c.path.parse().unwrap();
94+
quote!(#i)
95+
}
96+
9197
pub fn hash_map(c: &HashMap<String, u32>) -> TokenStream {
9298
let entries: TokenStream = c
9399
.iter()

0 commit comments

Comments
 (0)