Skip to content

Commit 3f82843

Browse files
authored
[ET-VK] Reducing precision of some in members in conv2d pw to improved performance.
Differential Revision: D75423958 Pull Request resolved: #11139
1 parent 6432ba4 commit 3f82843

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5050
void main() {
5151
const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y};
5252

53-
const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
54-
const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};
53+
const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]);
54+
const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x};
55+
const int out_pos_z = int(gl_GlobalInvocationID.y);
5556

5657
// If the top left position is out of bounds, then this invocation will have
5758
// no work to do.
58-
if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
59+
if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) {
5960
return;
6061
}
6162

@@ -68,8 +69,8 @@ void main() {
6869
uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
6970
for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
7071
for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
71-
pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x;
72-
pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y;
72+
pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x;
73+
pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y;
7374
i++;
7475
}
7576
}
@@ -78,7 +79,7 @@ void main() {
7879
// Tuple of consecutive 4 elements represents a single output texel.
7980
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
8081

81-
const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);
82+
const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);
8283

8384
// Initialize the output array with the bias value
8485
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
@@ -98,7 +99,7 @@ void main() {
9899

99100
// Load kernel values from texels to array
100101
[[unroll]] for (int i = 0; i < 4; ++i) {
101-
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
102+
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0);
102103
kernel_values[i * 4 + 0] = k_tex.x;
103104
kernel_values[i * 4 + 1] = k_tex.y;
104105
kernel_values[i * 4 + 2] = k_tex.z;
@@ -157,8 +158,8 @@ void main() {
157158
}
158159

159160
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
160-
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
161-
if (all(lessThan(pos_l, out_limits.xyz))) {
161+
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z);
162+
if (all(lessThan(pos_l.xy, out_limits.xy))) {
162163
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
163164
}
164165
}

0 commit comments

Comments
 (0)