Skip to content

Commit b1dfb5a

Browse files
committed
[ET-VK] Storing positions in uint16 to instead of int in conv2d pw shader.
Pull Request resolved: #11138 This diff modifies the `conv2d_pw_s1p0.glsl` shader to store positions in `uint16` instead of `int`. The changes include adding the necessary extension for explicit arithmetic types, updating the type definitions for `TILE_SIZE_X` and `TILE_SIZE_Y`, and changing the type of the `pos` array. ghstack-source-id: 286652102 @exported-using-ghexport Differential Revision: [D75423935](https://our.internmc.facebook.com/intern/diff/D75423935/)
1 parent a5ac056 commit b1dfb5a

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88

99
#version 450 core
1010

11+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
12+
1113
#define PRECISION ${PRECISION}
1214

1315
#define VEC4_T ${texel_type(DTYPE)}
1416

15-
#define TILE_SIZE_X ${TILE_SIZE_X}
16-
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define TILE_SIZE_X uint16_t(${TILE_SIZE_X})
18+
#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y})
1719

1820
#define op(X, A, B) ${OPERATOR}
1921

@@ -63,11 +65,11 @@ void main() {
6365
// +--------+--------+
6466
// | pos[2] | pos[3] |
6567
// +--------+--------+
66-
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
67-
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
68-
for (int x = 0; x < TILE_SIZE_X; ++x) {
69-
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
70-
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
68+
uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
69+
for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
70+
for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
71+
pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x;
72+
pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y;
7173
i++;
7274
}
7375
}

0 commit comments

Comments
 (0)