Skip to content

Commit a4985a8

Browse files
[ET-VK] De vectorise positions in conv2d pw shader to improve perf. (#11186)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11122 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/93/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/93/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/92/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/93/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent 22e7dbd commit a4985a8

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,32 +58,28 @@ void main() {
5858
return;
5959
}
6060

61-
// If the top left position is out of bounds, then this invocation will have
62-
// no work to do.
63-
if (gpos.z >= out_limits.z) {
64-
return;
65-
}
66-
6761
// Output position for TILE_SIZE = 2
6862
// +--------+--------+
6963
// | pos[0] | pos[1] |
7064
// +--------+--------+
7165
// | pos[2] | pos[3] |
7266
// +--------+--------+
73-
ivec3 pos[TILE_SIZE_X * TILE_SIZE_Y];
67+
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
7468
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7569
for (int x = 0; x < TILE_SIZE_X; ++x) {
76-
pos[i] = ivec3(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z);
70+
pos[i * 2] = gpos.x * TILE_SIZE_X + x;
71+
pos[i * 2 + 1] = gpos.y * TILE_SIZE_Y + y;
7772
i++;
7873
}
7974
}
8075

8176
// Compute the index of the input texture that needs to be loaded for each
8277
// output position. Note that negative indices can be produced indicating that
8378
// the top-left element is in a region added by padding.
84-
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
79+
int ipos[TILE_SIZE_X * TILE_SIZE_Y * 2];
8580
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
86-
ipos[i] = pos[i].xy * stride - padding;
81+
ipos[i * 2] = pos[i * 2] * stride.x - padding.x;
82+
ipos[i * 2 + 1] = pos[i * 2 + 1] * stride.y - padding.y;
8783
}
8884

8985
// Final output array where each element is a tensor value.
@@ -118,7 +114,7 @@ void main() {
118114
}
119115

120116
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
121-
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
117+
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i * 2], ipos[i * 2 + 1], z4), 0);
122118
// Load the input texel into an array
123119
float tex_values[4];
124120
tex_values[0] = in_tex.x;
@@ -169,8 +165,9 @@ void main() {
169165
}
170166

171167
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
172-
if (all(lessThan(pos[i], out_limits.xyz))) {
173-
imageStore(t_out, pos[i], op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
168+
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], gpos.z);
169+
if (all(lessThan(pos_l, out_limits.xyz))) {
170+
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
174171
}
175172
}
176173
}

0 commit comments

Comments
 (0)