Skip to content

Commit 47e49bd

Browse files
authored
[ET-VK] Remove the use of shared memory in conv2d pw to improve perf.
Differential Revision: D75316188 Pull Request resolved: #11110
1 parent 28886cd commit 47e49bd

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#define TILE_SIZE_X ${TILE_SIZE_X}
1616
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17-
#define LOCAL_WG_SIZE 64
1817

1918
#define op(X, A, B) ${OPERATOR}
2019

@@ -39,53 +38,46 @@ layout(push_constant) uniform restrict Block {
3938

4039
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4140

42-
// For performance improvement, reduce register usage by caching positions in shared memory.
43-
// Offset index by 1 every 16 points to avoid bank access conflict.
44-
#define offset_pos_index(index) (index + ((index) >> 4))
45-
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
46-
4741
/*
4842
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
4943
* output tile for pointwise convolution is more efficient because the kernel
5044
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
5145
*/
5246
void main() {
5347
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
54-
const uint shared_mem_stride = LOCAL_WG_SIZE;
5548

5649
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
5750
const ivec3 gpos = ivec3(
5851
gl_GlobalInvocationID.x % out_limits_scaled.x,
5952
div_by_x % out_limits_scaled.y,
6053
div_by_x / out_limits_scaled.y);
6154

55+
// If the top left position is out of bounds, then this invocation will have
56+
// no work to do.
57+
if (gpos.z >= out_limits.z) {
58+
return;
59+
}
60+
6261
// Output position for TILE_SIZE = 2
6362
// +--------+--------+
6463
// | pos[0] | pos[1] |
6564
// +--------+--------+
6665
// | pos[2] | pos[3] |
6766
// +--------+--------+
68-
ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
67+
ivec3 pos[TILE_SIZE_X * TILE_SIZE_Y];
6968
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7069
for (int x = 0; x < TILE_SIZE_X; ++x) {
71-
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
72-
pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
70+
pos[i] = ivec3(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z);
7371
i++;
7472
}
7573
}
7674

77-
// If the top left position is out of bounds, then this invocation will have
78-
// no work to do.
79-
if (gpos.z >= out_limits.z) {
80-
return;
81-
}
82-
8375
// Compute the index of the input texture that needs to be loaded for each
8476
// output position. Note that negative indices can be produced indicating that
8577
// the top-left element is in a region added by padding.
8678
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
8779
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
88-
ipos[i] = pos[i] * stride - padding;
80+
ipos[i] = pos[i].xy * stride - padding;
8981
}
9082

9183
// Final output array where each element is a tensor value.
@@ -171,10 +163,8 @@ void main() {
171163
}
172164

173165
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
174-
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
175-
const ivec3 pos = pos_shared[offset_pos_index(index)];
176-
if (all(lessThan(pos, out_limits.xyz))) {
177-
imageStore(t_out, pos, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
166+
if (all(lessThan(pos[i], out_limits.xyz))) {
167+
imageStore(t_out, pos[i], op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
178168
}
179169
}
180170
}

0 commit comments

Comments
 (0)