Skip to content

Commit 4645e57

Browse files
authored
[ET-VK] De-vectorise all vectors in the conv2d_pw shader to improve perf.
Differential Revision: D75423245 Pull Request resolved: #11136
1 parent 2b253e3 commit 4645e57

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,14 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4646
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4747
*/
4848
void main() {
49-
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
49+
const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y};
5050

51-
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
52-
const ivec3 gpos = ivec3(
53-
gl_GlobalInvocationID.x % out_limits_scaled.x,
54-
div_by_x,
55-
gl_GlobalInvocationID.y);
51+
const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
52+
const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};
5653

5754
// If the top left position is out of bounds, then this invocation will have
5855
// no work to do.
59-
if (gpos.y >= out_limits_scaled.y || gpos.z >= out_limits.z) {
56+
if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
6057
return;
6158
}
6259

@@ -69,8 +66,8 @@ void main() {
6966
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
7067
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7168
for (int x = 0; x < TILE_SIZE_X; ++x) {
72-
pos[i * 2] = gpos.x * TILE_SIZE_X + x;
73-
pos[i * 2 + 1] = gpos.y * TILE_SIZE_Y + y;
69+
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
70+
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
7471
i++;
7572
}
7673
}
@@ -88,7 +85,7 @@ void main() {
8885
// Tuple of consecutive 4 elements represents a single output texel.
8986
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
9087

91-
const vec4 bias = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
88+
const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);
9289

9390
// Initialize the output array with the bias value
9491
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
@@ -108,7 +105,7 @@ void main() {
108105

109106
// Load kernel values from texels to array
110107
[[unroll]] for (int i = 0; i < 4; ++i) {
111-
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, gpos.z), 0);
108+
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
112109
kernel_values[i * 4 + 0] = k_tex.x;
113110
kernel_values[i * 4 + 1] = k_tex.y;
114111
kernel_values[i * 4 + 2] = k_tex.z;
@@ -167,7 +164,7 @@ void main() {
167164
}
168165

169166
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
170-
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], gpos.z);
167+
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
171168
if (all(lessThan(pos_l, out_limits.xyz))) {
172169
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
173170
}

0 commit comments

Comments
 (0)