Skip to content

Commit b44b606

Browse files
committed
bevy_pbr: Avoid copying structs and using registers in shaders (#7069)
# Objective - The #7064 PR had poor performance on an M1 Max in MacOS due to significant overuse of registers resulting in 'register spilling' where data that would normally be stored in registers on the GPU is instead stored in VRAM. The latency to read from/write to VRAM instead of registers incurs a significant performance penalty. - Use of registers is a limiting factor in shader performance. Assignment of a struct from memory to a local variable can incur copies. Passing a variable that has struct type as an argument to a function can also incur copies. As such, these two cases can incur increased register usage and decreased performance. ## Solution - Remove/avoid a number of assignments of light struct type data to local variables. - Remove/avoid a number of passing light struct type variables/data as value arguments to shader functions.
1 parent b833bda commit b44b606

File tree

3 files changed

+51
-49
lines changed

3 files changed

+51
-49
lines changed

crates/bevy_pbr/src/render/pbr_functions.wgsl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,38 +196,35 @@ fn pbr(
196196
// point lights
197197
for (var i: u32 = offset_and_counts[0]; i < offset_and_counts[0] + offset_and_counts[1]; i = i + 1u) {
198198
let light_id = get_light_id(i);
199-
let light = point_lights.data[light_id];
200199
var shadow: f32 = 1.0;
201200
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
202-
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
201+
&& (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
203202
shadow = fetch_point_shadow(light_id, in.world_position, in.world_normal);
204203
}
205-
let light_contrib = point_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
204+
let light_contrib = point_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
206205
light_accum = light_accum + light_contrib * shadow;
207206
}
208207

209208
// spot lights
210209
for (var i: u32 = offset_and_counts[0] + offset_and_counts[1]; i < offset_and_counts[0] + offset_and_counts[1] + offset_and_counts[2]; i = i + 1u) {
211210
let light_id = get_light_id(i);
212-
let light = point_lights.data[light_id];
213211
var shadow: f32 = 1.0;
214212
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
215-
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
213+
&& (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
216214
shadow = fetch_spot_shadow(light_id, in.world_position, in.world_normal);
217215
}
218-
let light_contrib = spot_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
216+
let light_contrib = spot_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
219217
light_accum = light_accum + light_contrib * shadow;
220218
}
221219

222220
let n_directional_lights = lights.n_directional_lights;
223221
for (var i: u32 = 0u; i < n_directional_lights; i = i + 1u) {
224-
let light = lights.directional_lights[i];
225222
var shadow: f32 = 1.0;
226223
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
227-
&& (light.flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
224+
&& (lights.directional_lights[i].flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
228225
shadow = fetch_directional_shadow(i, in.world_position, in.world_normal);
229226
}
230-
let light_contrib = directional_light(light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
227+
let light_contrib = directional_light(i, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
231228
light_accum = light_accum + light_contrib * shadow;
232229
}
233230

crates/bevy_pbr/src/render/pbr_lighting.wgsl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,23 @@ fn perceptualRoughnessToRoughness(perceptualRoughness: f32) -> f32 {
150150
}
151151

152152
fn point_light(
153-
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
153+
world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
154154
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
155155
) -> vec3<f32> {
156-
let light_to_frag = light.position_radius.xyz - world_position.xyz;
156+
let light = &point_lights.data[light_id];
157+
let light_to_frag = (*light).position_radius.xyz - world_position.xyz;
157158
let distance_square = dot(light_to_frag, light_to_frag);
158159
let rangeAttenuation =
159-
getDistanceAttenuation(distance_square, light.color_inverse_square_range.w);
160+
getDistanceAttenuation(distance_square, (*light).color_inverse_square_range.w);
160161

161162
// Specular.
162163
// Representative Point Area Lights.
163164
// see http://blog.selfshadow.com/publications/s2013-shading-course/karis/s2013_pbs_epic_notes_v2.pdf p14-16
164165
let a = roughness;
165166
let centerToRay = dot(light_to_frag, R) * R - light_to_frag;
166-
let closestPoint = light_to_frag + centerToRay * saturate(light.position_radius.w * inverseSqrt(dot(centerToRay, centerToRay)));
167+
let closestPoint = light_to_frag + centerToRay * saturate((*light).position_radius.w * inverseSqrt(dot(centerToRay, centerToRay)));
167168
let LspecLengthInverse = inverseSqrt(dot(closestPoint, closestPoint));
168-
let normalizationFactor = a / saturate(a + (light.position_radius.w * 0.5 * LspecLengthInverse));
169+
let normalizationFactor = a / saturate(a + ((*light).position_radius.w * 0.5 * LspecLengthInverse));
169170
let specularIntensity = normalizationFactor * normalizationFactor;
170171

171172
var L: vec3<f32> = closestPoint * LspecLengthInverse; // normalize() equivalent?
@@ -197,40 +198,44 @@ fn point_light(
197198
// I = Φ / 4 π
198199
// The derivation of this can be seen here: https://google.github.io/filament/Filament.html#mjx-eqn-pointLightLuminousPower
199200

200-
// NOTE: light.color.rgb is premultiplied with light.intensity / 4 π (which would be the luminous intensity) on the CPU
201+
// NOTE: (*light).color.rgb is premultiplied with (*light).intensity / 4 π (which would be the luminous intensity) on the CPU
201202

202203
// TODO compensate for energy loss https://google.github.io/filament/Filament.html#materialsystem/improvingthebrdfs/energylossinspecularreflectance
203204

204-
return ((diffuse + specular_light) * light.color_inverse_square_range.rgb) * (rangeAttenuation * NoL);
205+
return ((diffuse + specular_light) * (*light).color_inverse_square_range.rgb) * (rangeAttenuation * NoL);
205206
}
206207

207208
fn spot_light(
208-
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
209+
world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
209210
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
210211
) -> vec3<f32> {
211212
// reuse the point light calculations
212-
let point_light = point_light(world_position, light, roughness, NdotV, N, V, R, F0, diffuseColor);
213+
let point_light = point_light(world_position, light_id, roughness, NdotV, N, V, R, F0, diffuseColor);
214+
215+
let light = &point_lights.data[light_id];
213216

214217
// reconstruct spot dir from x/z and y-direction flag
215-
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y);
218+
var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
216219
spot_dir.y = sqrt(max(0.0, 1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z));
217-
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
220+
if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
218221
spot_dir.y = -spot_dir.y;
219222
}
220-
let light_to_frag = light.position_radius.xyz - world_position.xyz;
223+
let light_to_frag = (*light).position_radius.xyz - world_position.xyz;
221224

222225
// calculate attenuation based on filament formula https://google.github.io/filament/Filament.html#listing_glslpunctuallight
223226
// spot_scale and spot_offset have been precomputed
224227
// note we normalize here to get "l" from the filament listing. spot_dir is already normalized
225228
let cd = dot(-spot_dir, normalize(light_to_frag));
226-
let attenuation = saturate(cd * light.light_custom_data.z + light.light_custom_data.w);
229+
let attenuation = saturate(cd * (*light).light_custom_data.z + (*light).light_custom_data.w);
227230
let spot_attenuation = attenuation * attenuation;
228231

229232
return point_light * spot_attenuation;
230233
}
231234

232-
fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> {
233-
let incident_light = light.direction_to_light.xyz;
235+
fn directional_light(light_id: u32, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> {
236+
let light = &lights.directional_lights[light_id];
237+
238+
let incident_light = (*light).direction_to_light.xyz;
234239

235240
let half_vector = normalize(incident_light + view);
236241
let NoL = saturate(dot(normal, incident_light));
@@ -241,5 +246,5 @@ fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal
241246
let specularIntensity = 1.0;
242247
let specular_light = specular(F0, roughness, half_vector, NdotV, NoL, NoH, LoH, specularIntensity);
243248

244-
return (specular_light + diffuse) * light.color.rgb * NoL;
249+
return (specular_light + diffuse) * (*light).color.rgb * NoL;
245250
}

crates/bevy_pbr/src/render/shadows.wgsl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
11
#define_import_path bevy_pbr::shadows
22

33
fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
4-
let light = point_lights.data[light_id];
4+
let light = &point_lights.data[light_id];
55

66
// because the shadow maps align with the axes and the frustum planes are at 45 degrees
77
// we can get the worldspace depth by taking the largest absolute axis
8-
let surface_to_light = light.position_radius.xyz - frag_position.xyz;
8+
let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;
99
let surface_to_light_abs = abs(surface_to_light);
1010
let distance_to_light = max(surface_to_light_abs.x, max(surface_to_light_abs.y, surface_to_light_abs.z));
1111

1212
// The normal bias here is already scaled by the texel size at 1 world unit from the light.
1313
// The texel size increases proportionally with distance from the light so multiplying by
1414
// distance to light scales the normal bias to the texel size at the fragment distance.
15-
let normal_offset = light.shadow_normal_bias * distance_to_light * surface_normal.xyz;
16-
let depth_offset = light.shadow_depth_bias * normalize(surface_to_light.xyz);
15+
let normal_offset = (*light).shadow_normal_bias * distance_to_light * surface_normal.xyz;
16+
let depth_offset = (*light).shadow_depth_bias * normalize(surface_to_light.xyz);
1717
let offset_position = frag_position.xyz + normal_offset + depth_offset;
1818

1919
// similar largest-absolute-axis trick as above, but now with the offset fragment position
20-
let frag_ls = light.position_radius.xyz - offset_position.xyz;
20+
let frag_ls = (*light).position_radius.xyz - offset_position.xyz;
2121
let abs_position_ls = abs(frag_ls);
2222
let major_axis_magnitude = max(abs_position_ls.x, max(abs_position_ls.y, abs_position_ls.z));
2323

2424
// NOTE: These simplifications come from multiplying:
2525
// projection * vec4(0, 0, -major_axis_magnitude, 1.0)
2626
// and keeping only the terms that have any impact on the depth.
2727
// Projection-agnostic approach:
28-
let zw = -major_axis_magnitude * light.light_custom_data.xy + light.light_custom_data.zw;
28+
let zw = -major_axis_magnitude * (*light).light_custom_data.xy + (*light).light_custom_data.zw;
2929
let depth = zw.x / zw.y;
3030

3131
// do the lookup, using HW PCF and comparison
@@ -42,27 +42,27 @@ fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: v
4242
}
4343

4444
fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
45-
let light = point_lights.data[light_id];
45+
let light = &point_lights.data[light_id];
4646

47-
let surface_to_light = light.position_radius.xyz - frag_position.xyz;
47+
let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;
4848

4949
// construct the light view matrix
50-
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y);
50+
var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
5151
// reconstruct spot dir from x/z and y-direction flag
5252
spot_dir.y = sqrt(1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z);
53-
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
53+
if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
5454
spot_dir.y = -spot_dir.y;
5555
}
5656

5757
// view matrix z_axis is the reverse of transform.forward()
5858
let fwd = -spot_dir;
5959
let distance_to_light = dot(fwd, surface_to_light);
60-
let offset_position =
61-
-surface_to_light
62-
+ (light.shadow_depth_bias * normalize(surface_to_light))
63-
+ (surface_normal.xyz * light.shadow_normal_bias) * distance_to_light;
60+
let offset_position =
61+
-surface_to_light
62+
+ ((*light).shadow_depth_bias * normalize(surface_to_light))
63+
+ (surface_normal.xyz * (*light).shadow_normal_bias) * distance_to_light;
6464

65-
// the construction of the up and right vectors needs to precisely mirror the code
65+
// the construction of the up and right vectors needs to precisely mirror the code
6666
// in render/light.rs:spot_light_view_matrix
6767
var sign = -1.0;
6868
if (fwd.z >= 0.0) {
@@ -74,14 +74,14 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
7474
let right_dir = vec3<f32>(-b, -sign - fwd.y * fwd.y * a, fwd.y);
7575
let light_inv_rot = mat3x3<f32>(right_dir, up_dir, fwd);
7676

77-
// because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate
78-
// the product of the transpose with a vector we can just post-multiply instead of pre-multplying.
77+
// because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate
78+
// the product of the transpose with a vector we can just post-multiply instead of pre-multplying.
7979
// this allows us to keep the matrix construction code identical between CPU and GPU.
8080
let projected_position = offset_position * light_inv_rot;
8181

8282
// divide xy by perspective matrix "f" and by -projected.z (projected.z is -projection matrix's w)
8383
// to get ndc coordinates
84-
let f_div_minus_z = 1.0 / (light.spot_light_tan_angle * -projected_position.z);
84+
let f_div_minus_z = 1.0 / ((*light).spot_light_tan_angle * -projected_position.z);
8585
let shadow_xy_ndc = projected_position.xy * f_div_minus_z;
8686
// convert to uv coordinates
8787
let shadow_uv = shadow_xy_ndc * vec2<f32>(0.5, -0.5) + vec2<f32>(0.5, 0.5);
@@ -90,23 +90,23 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
9090
let depth = 0.1 / -projected_position.z;
9191

9292
#ifdef NO_ARRAY_TEXTURES_SUPPORT
93-
return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler,
93+
return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler,
9494
shadow_uv, depth);
9595
#else
96-
return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler,
96+
return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler,
9797
shadow_uv, i32(light_id) + lights.spot_light_shadowmap_offset, depth);
9898
#endif
9999
}
100100

101101
fn fetch_directional_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
102-
let light = lights.directional_lights[light_id];
102+
let light = &lights.directional_lights[light_id];
103103

104104
// The normal bias is scaled to the texel size.
105-
let normal_offset = light.shadow_normal_bias * surface_normal.xyz;
106-
let depth_offset = light.shadow_depth_bias * light.direction_to_light.xyz;
105+
let normal_offset = (*light).shadow_normal_bias * surface_normal.xyz;
106+
let depth_offset = (*light).shadow_depth_bias * (*light).direction_to_light.xyz;
107107
let offset_position = vec4<f32>(frag_position.xyz + normal_offset + depth_offset, frag_position.w);
108108

109-
let offset_position_clip = light.view_projection * offset_position;
109+
let offset_position_clip = (*light).view_projection * offset_position;
110110
if (offset_position_clip.w <= 0.0) {
111111
return 1.0;
112112
}

0 commit comments

Comments
 (0)