
Commit 1f29600

Update on "[ET-VK] Store weights transposed for int8 linear"
## Context

The weight tensor of a linear layer is usually stored in a transposed manner, so that when computing the matrix multiplication, the reduction traverses along the rows of the weight tensor as opposed to the columns. This results in a better memory access pattern for CPUs. However, for GPUs, I have found that "un-transposing" the weight tensor results in better performance. This is likely because GPUs compute multiple output elements in parallel, so reading along the columns allows memory loads to be coalesced among the threads in a work group.

## Changes

* Introduce the ability to transpose the height and width dims when transferring tensor data to the GPU.
* Prepack the weight tensor "un-transposed" for the int8 quantized linear operator.

Differential Revision: [D72066588](https://our.internmc.facebook.com/intern/diff/D72066588/)

[ghstack-poisoned]
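
To make the access-pattern argument concrete, here is a minimal sketch of an int8 linear compute shader that stores the weight in an "un-transposed" [K, N] layout. This is not the actual `q_8w_linear.glsl` shader from this commit; the buffer names, size names, and per-output-channel scale handling (`t_mat1`, `t_qmat2`, `t_scales`, `M`, `K`, `N`) are placeholders chosen for illustration.

```glsl
#version 450

// Minimal illustrative sketch, not the actual q_8w_linear.glsl shader.
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

layout(set = 0, binding = 0) buffer OutBuf   { float t_out[]; };    // output, [M, N]
layout(set = 0, binding = 1) buffer Mat1Buf  { float t_mat1[]; };   // input, [M, K]
layout(set = 0, binding = 2) buffer QMat2Buf { int   t_qmat2[]; };  // quantized weights widened to int, [K, N]
layout(set = 0, binding = 3) buffer ScaleBuf { float t_scales[]; }; // per-output-channel scales (illustrative)

layout(push_constant) uniform Sizes {
  int M;
  int K;
  int N;
};

void main() {
  const int n = int(gl_GlobalInvocationID.x); // output column; consecutive across invocations
  const int m = int(gl_GlobalInvocationID.y); // output row
  if (n >= N || m >= M) {
    return;
  }

  float acc = 0.0;
  for (int k = 0; k < K; ++k) {
    // With the [K, N] layout, invocation n reads t_qmat2[k * N + n] while its
    // neighbor reads t_qmat2[k * N + n + 1]: consecutive addresses, so the weight
    // loads can be coalesced within the work group. With a transposed [N, K]
    // layout the same two reads would be K elements apart.
    acc += t_mat1[m * K + k] * float(t_qmat2[k * N + n]);
  }

  t_out[m * N + n] = acc * t_scales[n];
}
```

Because consecutive invocations along x map to consecutive output columns, each iteration of the reduction issues a contiguous run of weight loads across the work group, which is the coalescing benefit described above.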
1 parent 708ecb7 commit 1f29600

File tree

5 files changed: +10 -25 lines changed

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl

+2 -7

@@ -47,13 +47,8 @@ ivec4 read_texel(ivec4 tidx) {
   ivec4 sizes_to_use = sizes;
   int packed_dim_to_use = packed_dim;
   if (transpose_hw == 1) {
-    int tmp = sizes_to_use.x;
-    sizes_to_use.x = sizes_to_use.y;
-    sizes_to_use.y = tmp;
-
-    tmp = tidx_to_use.x;
-    tidx_to_use.x = tidx.y;
-    tidx_to_use.y = tmp;
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;

     if (packed_dim == 1) {
       packed_dim_to_use = 0;

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

+2 -7

@@ -33,13 +33,8 @@ void main() {

   ivec4 sizes = out_sizes;
   if (transpose_hw == 1) {
-    int tmp = sizes.x;
-    sizes.x = sizes.y;
-    sizes.y = tmp;
-
-    tmp = out_tidx.x;
-    out_tidx.x = out_tidx.y;
-    out_tidx.y = tmp;
+    sizes.xy = sizes.yx;
+    out_tidx.xy = out_tidx.yx;
   }

   const int in_nchwi = tidx_to_nchwi(out_tidx, sizes);

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

+2 -7

@@ -40,13 +40,8 @@ VEC4_T read_texel(ivec4 tidx) {
   ivec4 sizes_to_use = sizes;
   int packed_dim_to_use = packed_dim;
   if (transpose_hw == 1) {
-    int tmp = sizes_to_use.x;
-    sizes_to_use.x = sizes_to_use.y;
-    sizes_to_use.y = tmp;
-
-    tmp = tidx_to_use.x;
-    tidx_to_use.x = tidx.y;
-    tidx_to_use.y = tmp;
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;

     if (packed_dim == 1) {
       packed_dim_to_use = 0;
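
The same two-line simplification appears in all three staging shaders above. It relies on GLSL swizzle assignment: the right-hand side is evaluated before the write, so assigning a `.yx` read to an `.xy` write swaps the two components in place with no temporary. A minimal standalone illustration (the function and values here are mine, not from the shaders):

```glsl
void swap_hw_example() {
  ivec4 sizes = ivec4(8, 3, 2, 1); // e.g. (width, height, channels, batch)
  // The right-hand side is evaluated first, so this swaps the first two
  // components in place: sizes becomes (3, 8, 2, 1).
  sizes.xy = sizes.yx;
}
```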

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

+2 -2

@@ -70,15 +70,15 @@ void main() {
   // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
   for (int i = 0; i < mat1_sizes.x; i++) {
     const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]) * scale;
+    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]);

     outval += mat1_val * mat2_val;

     mat1_offset++;
     qmat2_offset += qmat2_strides.y;
   }

-  t_out[out_bufi] = outval;
+  t_out[out_bufi] = outval * scale;
 }

#else // USING_TEXTURE
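
A note on the `q_8w_linear.glsl` change above (my reading of the diff, not text from the commit): `scale` does not depend on the loop index, and `sum_i(mat1[i] * (qmat2[i] * scale))` equals `scale * sum_i(mat1[i] * qmat2[i])`, so the multiply by `scale` can be applied once to the accumulated value after the loop instead of on every weight load. A minimal sketch of the transformation, with a hypothetical helper name and a fixed-size array used only for illustration:

```glsl
const int K_MAX = 256; // illustrative bound, not from the shader

float dot_q8w(const int K, const float a[K_MAX], const int w[K_MAX], const float scale) {
  float acc = 0.0;
  for (int k = 0; k < K; ++k) {
    acc += a[k] * float(w[k]); // no per-iteration multiply by scale
  }
  return acc * scale; // scale applied once per output element
}
```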

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

+2 -2

@@ -175,8 +175,8 @@ ValueRef prepack_standard_hw_transposed(
   const int w_dim = new_out_sizes.size() - 1;
   const int h_dim = new_out_sizes.size() - 2;
   const int64_t tmp = new_out_sizes.at(w_dim);
-  new_out_sizes[w_dim] = new_out_sizes[h_dim];
-  new_out_sizes[h_dim] = tmp;
+  new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim);
+  new_out_sizes.at(h_dim) = tmp;
   ValueRef tensor = graph.add_tensor(
       new_out_sizes,
       graph.dtype_of(tensor_data),
