
Commit 7a20e01

Update on "[ET-VK] Efficient tiled int8 matmul"
## Context

Introduce an optimized tiled implementation for computing the weight int8-quantized linear operation. This implementation takes advantage of the following principles to squeeze out performance:

* Compute an output tile with each thread, rather than a single output element. This allows for better memory reuse of loaded input tensor data.
* Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting the data loading and computation into distinct steps, the GPU is able to hide latency more effectively, i.e. switch to a warp that is ready to perform compute while the current warp waits on a data load.
* Use a work group size of `{N, 1, 1}`. This makes it so that all the threads in a work group load the same row of the input matrix and consecutive columns of the weight matrix. This way, the row of the input is kept hot in the cache, and accesses to the weight matrix can be coalesced due to the previous diff un-transposing the weight matrix.

Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/)

[ghstack-poisoned]
2 parents 8f3b16f + d73f38f commit 7a20e01
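To make the principles in the commit message concrete, below is a simplified, illustrative GLSL compute-shader sketch of a register-tiled matmul with int8 weights and per-output-channel scales. It is not the shader introduced by this commit: the tile sizes, buffer layouts, and all identifiers (`TILE_ROWS`, `TILE_COLS`, `t_mat1`, `t_qmat2`, `t_scales`, `t_out`, `M`/`K`/`N`) are assumptions chosen for readability, and the 8-bit weights are shown here as `ivec4` loads for simplicity.

```glsl
#version 450

// Illustrative sketch only: a register-tiled matmul with int8 weights.
// All identifiers and tile sizes are placeholders, not the commit's real names.

#define TILE_ROWS 4   // output rows computed per thread
#define TILE_COLS 4   // output cols computed per thread (one vec4)

layout(std430) buffer;

layout(set = 0, binding = 0) writeonly buffer OutBuf   { vec4  t_out[];    };
layout(set = 0, binding = 1) readonly  buffer Mat1Buf  { vec4  t_mat1[];   };
layout(set = 0, binding = 2) readonly  buffer QMat2Buf { ivec4 t_qmat2[];  };
layout(set = 0, binding = 3) readonly  buffer ScaleBuf { vec4  t_scales[]; };

layout(push_constant) uniform Sizes {
  int M; // rows of mat1 / output
  int K; // inner dimension (assumed to be a multiple of 4)
  int N; // cols of the un-transposed weight / output (assumed multiple of 4)
};

// Work group size of {N', 1, 1}: threads in a group share the same input rows
// and read consecutive weight columns, so the input row stays hot in cache and
// weight reads can coalesce.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

void main() {
  const int out_col = int(gl_GlobalInvocationID.x) * TILE_COLS;
  const int out_row = int(gl_GlobalInvocationID.y) * TILE_ROWS;

  // For brevity, assume M and N are multiples of the tile dimensions.
  if (out_row >= M || out_col >= N) {
    return;
  }

  // Output tile accumulated in registers.
  vec4 acc[TILE_ROWS];
  for (int r = 0; r < TILE_ROWS; ++r) {
    acc[r] = vec4(0.0);
  }

  for (int k = 0; k < K; k += 4) {
    // Step 1: load a (TILE_ROWS x 4) tile of mat1 and a (4 x TILE_COLS) tile
    // of the int8 weight matrix into registers.
    vec4 a[TILE_ROWS];
    for (int r = 0; r < TILE_ROWS; ++r) {
      a[r] = t_mat1[((out_row + r) * K + k) / 4];
    }
    vec4 b[4];
    for (int kk = 0; kk < 4; ++kk) {
      b[kk] = vec4(t_qmat2[((k + kk) * N + out_col) / 4]);
    }

    // Step 2: fma-accumulate the partial output tile from the cached tiles.
    for (int r = 0; r < TILE_ROWS; ++r) {
      acc[r] = fma(vec4(a[r].x), b[0], acc[r]);
      acc[r] = fma(vec4(a[r].y), b[1], acc[r]);
      acc[r] = fma(vec4(a[r].z), b[2], acc[r]);
      acc[r] = fma(vec4(a[r].w), b[3], acc[r]);
    }
  }

  // Apply the per-output-channel dequantization scales once, at the end.
  const vec4 scales = t_scales[out_col / 4];
  for (int r = 0; r < TILE_ROWS; ++r) {
    t_out[((out_row + r) * N + out_col) / 4] = acc[r] * scales;
  }
}
```

The key structure is the split between the register loads (step 1) and the `fma` accumulation (step 2), plus the fact that threads along `local_size_x` share the same input rows while reading consecutive weight columns.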

File tree: 5 files changed, +10 −25 lines


backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl

Lines changed: 2 additions & 7 deletions
@@ -47,13 +47,8 @@ ivec4 read_texel(ivec4 tidx) {
   ivec4 sizes_to_use = sizes;
   int packed_dim_to_use = packed_dim;
   if (transpose_hw == 1) {
-    int tmp = sizes_to_use.x;
-    sizes_to_use.x = sizes_to_use.y;
-    sizes_to_use.y = tmp;
-
-    tmp = tidx_to_use.x;
-    tidx_to_use.x = tidx.y;
-    tidx_to_use.y = tmp;
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
 
     if (packed_dim == 1) {
       packed_dim_to_use = 0;
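A side note on the swizzle idiom adopted in this and the two similar shader changes below: GLSL evaluates the right-hand side of a swizzled assignment in full before writing it back, so two components can be swapped without the explicit temporary. A minimal standalone sketch (a hypothetical helper, not part of the shaders):

```glsl
// Hypothetical helper showing the swizzle-based swap used for transpose_hw:
// the .yx swizzle on the right is read in full before .xy is written.
ivec4 swap_width_height(ivec4 sizes) {
  sizes.xy = sizes.yx;  // e.g. (W, H, C, N) -> (H, W, C, N)
  return sizes;
}
```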

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

Lines changed: 2 additions & 7 deletions
@@ -33,13 +33,8 @@ void main() {
 
   ivec4 sizes = out_sizes;
   if (transpose_hw == 1) {
-    int tmp = sizes.x;
-    sizes.x = sizes.y;
-    sizes.y = tmp;
-
-    tmp = out_tidx.x;
-    out_tidx.x = out_tidx.y;
-    out_tidx.y = tmp;
+    sizes.xy = sizes.yx;
+    out_tidx.xy = out_tidx.yx;
   }
   const int in_nchwi = tidx_to_nchwi(out_tidx, sizes);
 
backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 2 additions & 7 deletions
@@ -40,13 +40,8 @@ VEC4_T read_texel(ivec4 tidx) {
   ivec4 sizes_to_use = sizes;
   int packed_dim_to_use = packed_dim;
   if (transpose_hw == 1) {
-    int tmp = sizes_to_use.x;
-    sizes_to_use.x = sizes_to_use.y;
-    sizes_to_use.y = tmp;
-
-    tmp = tidx_to_use.x;
-    tidx_to_use.x = tidx.y;
-    tidx_to_use.y = tmp;
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
 
     if (packed_dim == 1) {
       packed_dim_to_use = 0;

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 2 additions & 2 deletions
@@ -70,15 +70,15 @@ void main() {
   // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
   for (int i = 0; i < mat1_sizes.x; i++) {
     const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]) * scale;
+    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]);
 
     outval += mat1_val * mat2_val;
 
     mat1_offset++;
     qmat2_offset += qmat2_strides.y;
   }
 
-  t_out[out_bufi] = outval;
+  t_out[out_bufi] = outval * scale;
 }
 
 #else // USING_TEXTURE
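The q_8w_linear.glsl change above hoists the dequantization scale out of the inner loop: since `scale` is constant for a given output element, scaling each weight inside the loop and scaling the accumulated sum once at the end are mathematically equivalent, and the latter saves one multiply per iteration. A minimal standalone sketch of the hoisted form (a hypothetical helper with fixed-size arrays, not the shader's actual code):

```glsl
// Hypothetical helper illustrating the loop-invariant scale hoist:
//   sum_i(a_i * (w_i * scale)) == (sum_i(a_i * w_i)) * scale
float dot_int8_dequant(float a[4], int w[4], float scale) {
  float acc = 0.0;
  for (int i = 0; i < 4; ++i) {
    // Accumulate the un-scaled integer-weight dot product.
    acc += a[i] * float(w[i]);
  }
  // Apply the per-output-channel scale once, after the loop.
  return acc * scale;
}
```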

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 2 additions & 2 deletions
@@ -175,8 +175,8 @@ ValueRef prepack_standard_hw_transposed(
   const int w_dim = new_out_sizes.size() - 1;
   const int h_dim = new_out_sizes.size() - 2;
   const int64_t tmp = new_out_sizes.at(w_dim);
-  new_out_sizes[w_dim] = new_out_sizes[h_dim];
-  new_out_sizes[h_dim] = tmp;
+  new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim);
+  new_out_sizes.at(h_dim) = tmp;
   ValueRef tensor = graph.add_tensor(
       new_out_sizes,
       graph.dtype_of(tensor_data),
