Commit 439d66d

[ET-VK] Store weights transposed for int8 linear
Differential Revision: D72066588
Pull Request resolved: #9765
Parent: 2aa7748
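
At a glance, this change prepacks the int8 weight of the quantized linear op with its height and width swapped (stored as [K, N] rather than [N, K]), adds a `transpose_hw` path to the staging shaders that perform the packing, and updates `q_8w_linear.glsl` to walk the weight along the transposed layout, applying the per-channel scale once after accumulation. The standalone C++ sketch below mirrors the updated buffer-shader math for illustration only; it is not part of the commit, and `q8w_linear_ref`, `w_t`, and `scales` are placeholder names.

#include <cstdint>
#include <cstdio>
#include <vector>

// out[m][n] = scales[n] * sum_k in[m][k] * w_t[k][n], with the int8 weight
// stored transposed as w_t of shape [K, N].
void q8w_linear_ref(
    const std::vector<float>& in,      // [M, K]
    const std::vector<int8_t>& w_t,    // [K, N], transposed weight
    const std::vector<float>& scales,  // [N], per-output-channel scales
    std::vector<float>& out,           // [M, N]
    int M, int K, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      int w_off = n;                    // like `qmat2_offset = out_tidx.x`
      for (int k = 0; k < K; ++k) {
        acc += in[m * K + k] * static_cast<float>(w_t[w_off]);
        w_off += N;                     // like `qmat2_offset += qmat2_strides.y`
      }
      out[m * N + n] = acc * scales[n]; // scale applied after the loop
    }
  }
}

int main() {
  const int M = 2, K = 4, N = 3;
  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int8_t> w_t(K * N, 1);
  std::vector<float> scales(N, 0.5f);
  std::vector<float> out(M * N, 0.0f);
  q8w_linear_ref(in, w_t, scales, out, M, K, N);
  printf("%.1f %.1f %.1f\n", out[0], out[1], out[2]);  // 5.0 5.0 5.0
  return 0;
}

In this layout, consecutive output columns map to adjacent weight elements at a fixed k, which is what the width-packed GPU tensor exposes to the shaders changed below.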

File tree: 8 files changed, +110 -26 lines

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl (+19 -2)

@@ -27,6 +27,8 @@ ${layout_declare_ubo(B, "ivec4", "sizes")}
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
+
 const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);
 
@@ -41,8 +43,23 @@ int extend_sign(int x) {
 }
 
 ivec4 read_texel(ivec4 tidx) {
+  ivec4 tidx_to_use = tidx;
+  ivec4 sizes_to_use = sizes;
+  int packed_dim_to_use = packed_dim;
+  if (transpose_hw == 1) {
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
+
+    if (packed_dim == 1) {
+      packed_dim_to_use = 0;
+    }
+    if (packed_dim == 0) {
+      packed_dim_to_use = 1;
+    }
+  }
+
   const ivec4 buf_indices = tidx_to_nchwi(
-      tidx, sizes, packed_dim);
+      tidx_to_use, sizes_to_use, packed_dim_to_use);
 
   int shift = (1 << 8) - 1;
   ivec4 masks;
@@ -70,7 +87,7 @@ ivec4 read_texel(ivec4 tidx) {
 
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
-  const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);
+  ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);
 
   if (any(greaterThanEqual(tidx, sizes))) {
     return;

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl (+8 -1)

@@ -21,6 +21,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 // This constant is unused in this shader but is kept so that the signature is
 // consistent with nchw_to_image.
 ${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
 
 void main() {
   int out_bufi = int(gl_GlobalInvocationID.x);
@@ -29,7 +30,13 @@ void main() {
   }
 
   ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
-  const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes);
+
+  ivec4 sizes = out_sizes;
+  if (transpose_hw == 1) {
+    sizes.xy = sizes.yx;
+    out_tidx.xy = out_tidx.yx;
+  }
+  const int in_nchwi = tidx_to_nchwi(out_tidx, sizes);
 
   t_out[out_bufi] = nchw_in[in_nchwi];
 }
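
The remap above is the simplest form of the new `transpose_hw` path: when the flag is set, the destination tensor index and the destination sizes are swapped in the height/width plane before being converted to a flat NCHW source index, so the packed tensor ends up holding the transposed data. A rough standalone C++ sketch of that index math follows; `in_nchw_index` is a made-up helper standing in for `tidx_to_nchwi` plus the swap, not an actual function in the codebase.

#include <cstdio>
#include <utility>

// Flat NCHW source index for a destination index (w, h, c, n) in a tensor with
// destination sizes (W, H, C). With transpose_hw set, both the index and the
// sizes are swapped in the h/w plane, as in nchw_to_buffer.glsl.
int in_nchw_index(int w, int h, int c, int n,
                  int W, int H, int C, bool transpose_hw) {
  if (transpose_hw) {  // sizes.xy = sizes.yx; out_tidx.xy = out_tidx.yx;
    std::swap(w, h);
    std::swap(W, H);
  }
  return ((n * C + c) * H + h) * W + w;  // tidx_to_nchwi for a contiguous buffer
}

int main() {
  // Source data is a 2x3 matrix stored row-major; the packed destination is
  // declared as its 3x2 transpose.
  const int src[6] = {0, 1, 2, 3, 4, 5};
  const int H_dst = 3, W_dst = 2;
  for (int h = 0; h < H_dst; ++h) {
    for (int w = 0; w < W_dst; ++w) {
      printf("%d ", src[in_nchw_index(w, h, 0, 0, W_dst, H_dst, 1, true)]);
    }
    printf("\n");
  }
  // Prints the transpose:
  // 0 3
  // 1 4
  // 2 5
  return 0;
}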

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl (+19 -2)

@@ -30,14 +30,31 @@ $if not FROM_STAGING:
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
+
 const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);
 
 VEC4_T read_texel(ivec4 tidx) {
+  ivec4 tidx_to_use = tidx;
+  ivec4 sizes_to_use = sizes;
+  int packed_dim_to_use = packed_dim;
+  if (transpose_hw == 1) {
+    sizes_to_use.xy = sizes_to_use.yx;
+    tidx_to_use.xy = tidx.yx;
+
+    if (packed_dim == 1) {
+      packed_dim_to_use = 0;
+    }
+    if (packed_dim == 0) {
+      packed_dim_to_use = 1;
+    }
+  }
+
   $if FROM_STAGING:
-    const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
+    const ivec4 buf_indices = tidx_to_nchwi(tidx_to_use, sizes_to_use, packed_dim_to_use);
   $else:
-    const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
+    const ivec4 buf_indices = tidx_to_4bufi(tidx_to_use, buf_strides, packed_dim_to_use);
 
   VEC4_T texel = VEC4_T(0);
   if (tidx[packed_dim] < sizes[packed_dim]) {

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl (+15 -16)

@@ -64,24 +64,21 @@ void main() {
 
   FLOAT_T outval = FLOAT_T(0.0);
 
-  // Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
   int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
-  // Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
-  // tensor is transposed
-  int qmat2_offset = out_tidx.x * qmat2_strides.y;
+  int qmat2_offset = out_tidx.x;
 
   // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
   for (int i = 0; i < mat1_sizes.x; i++) {
     const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
+    const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]);
 
     outval += mat1_val * mat2_val;
 
     mat1_offset++;
-    qmat2_offset++;
+    qmat2_offset += qmat2_strides.y;
   }
 
-  t_out[out_bufi] = outval;
+  t_out[out_bufi] = outval * scale;
 }
 
 #else // USING_TEXTURE
@@ -97,25 +94,27 @@ void main() {
     return;
   }
 
-  const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
+  const uint16_t qmat2_pos_x = out_pos.x;
 
   VEC4_T outtex = VEC4_T(0);
 
   const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
 
+  VEC4_T mat1_tex;
+  VEC4_T mat2_tex[4];
   for (
       uint16_t i = uint16_t(0), x = uint16_t(0);
       i < uint16_t(mat1_sizes.x);
       i += uint16_t(4), x++)
   {
-    const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
-    const VEC4_T sums = VEC4_T(
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
-        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));
-
-    outtex += sums;
+    mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
+
+    mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
+    mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
+    mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
+    mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
+
+    outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
   }
 
   outtex *= scales;
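
For the texture path, the transposed, width-packed weight appears to place the weights of four adjacent output columns at reduction index k in the texel at (x = n / 4, y = k), so each loop step combines one input texel with four weight texels and updates four outputs at once. A small CPU sketch of that accumulation pattern follows; it is an illustration only, and `Vec4`, `qmat2_t`, and the shapes are assumptions rather than the shader's actual types.

#include <array>
#include <cstdio>
#include <vector>

using Vec4 = std::array<float, 4>;

int main() {
  const int K = 8;  // reduction dim, assumed divisible by 4
  std::vector<float> mat1(K, 1.0f);                // one input row
  std::vector<Vec4> qmat2_t(K, Vec4{1, 2, 3, 4});  // 4 output columns per k
  const Vec4 scales{0.5f, 0.5f, 0.5f, 0.5f};

  Vec4 outtex{0, 0, 0, 0};
  for (int i = 0; i < K; i += 4) {
    // mat1_tex = load_texel(t_mat1, (x, out_pos.y, 0))
    const Vec4 mat1_tex{mat1[i], mat1[i + 1], mat1[i + 2], mat1[i + 3]};
    // mat2_tex[j] = load_texel(t_qmat2, (out_pos.x, i + j, 0))
    for (int j = 0; j < 4; ++j) {
      for (int c = 0; c < 4; ++c) {
        outtex[c] += mat1_tex[j] * qmat2_t[i + j][c];
      }
    }
  }
  for (int c = 0; c < 4; ++c) {
    outtex[c] *= scales[c];  // outtex *= scales
  }
  printf("%.1f %.1f %.1f %.1f\n", outtex[0], outtex[1], outtex[2], outtex[3]);
  // With all inputs 1.0 and K = 8: 4.0 8.0 12.0 16.0
  return 0;
}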

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp (+2 -2)

@@ -48,7 +48,7 @@ void resize_q_8w_linear_node(
   vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]);
 
   const int out_cols = utils::val_at(-2, mat1->sizes());
-  const int out_rows = utils::val_at(-2, qmat2->sizes());
+  const int out_rows = utils::val_at(-1, qmat2->sizes());
 
   std::vector<int64_t> new_out_sizes(3);
   if (mat1->sizes().size() == 2) {
@@ -86,7 +86,7 @@ void add_q_8w_linear_node(
     // Ensure out is packed correctly
     out_W_packed = out_tmp;
   }
-  ValueRef q_mat2 = prepack_standard(
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
      graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
  ValueRef scales = prepack_standard(
      graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/Staging.cpp (+32 -2)

@@ -113,7 +113,8 @@ void add_tensor_to_staging_node(
 void add_prepack_standard_node(
     ComputeGraph& graph,
     const ValueRef tensor_data,
-    const ValueRef tensor) {
+    const ValueRef tensor,
+    const bool transpose_hw = false) {
   vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
       *graph.get_tensor(tensor), graph.int8_buffers_enabled());
 
@@ -127,6 +128,8 @@ void add_prepack_standard_node(
     ubos.append({graph.sizes_ubo(tensor)});
   }
 
+  int transpose_hw_spec = transpose_hw ? 1 : 0;
+
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
       shader,
@@ -138,7 +141,7 @@ void add_prepack_standard_node(
       // Parameter Buffers
       ubos,
       // Specialization Constants
-      {graph.hashed_layout_of(tensor)}));
+      {graph.hashed_layout_of(tensor), transpose_hw_spec}));
 }
 
 ValueRef prepack_standard(
@@ -158,6 +161,33 @@ ValueRef prepack_standard(
   return tensor;
 }
 
+ValueRef prepack_standard_hw_transposed(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough,
+    const utils::AxisMapLayout axis_map_layout) {
+  (void)passthrough;
+
+  VK_CHECK_COND(graph.val_is_tref(tensor_data));
+  std::vector<int64_t> new_out_sizes = graph.sizes_of(tensor_data);
+  const int w_dim = new_out_sizes.size() - 1;
+  const int h_dim = new_out_sizes.size() - 2;
+  const int64_t tmp = new_out_sizes.at(w_dim);
+  new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim);
+  new_out_sizes.at(h_dim) = tmp;
+  ValueRef tensor = graph.add_tensor(
+      new_out_sizes,
+      graph.dtype_of(tensor_data),
+      storage_type,
+      layout,
+      -1,
+      axis_map_layout);
+  add_prepack_standard_node(graph, tensor_data, tensor, true);
+  return tensor;
+}
+
 ValueRef prepack_standard_like(
     ComputeGraph& graph,
     const ValueRef tensor_data,
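
The new prepack_standard_hw_transposed above only swaps the last two entries of the size vector on the host; the element reordering itself is left to the staging shader via the transpose_hw specialization constant, which keeps the transpose inside the existing packing dispatch rather than adding a separate pass. A standalone sketch of just that size swap (hw_transposed_sizes is a hypothetical helper, not the ExecuTorch API):

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Swap the last two (height/width) entries of a size vector, mirroring what
// prepack_standard_hw_transposed does before allocating the GPU tensor.
std::vector<int64_t> hw_transposed_sizes(std::vector<int64_t> sizes) {
  const size_t w_dim = sizes.size() - 1;
  const size_t h_dim = sizes.size() - 2;
  std::swap(sizes.at(w_dim), sizes.at(h_dim));
  return sizes;
}

int main() {
  // An int8 linear weight of shape [N, K] = [256, 1024] gets packed as [K, N].
  const std::vector<int64_t> packed = hw_transposed_sizes({256, 1024});
  printf("[%lld, %lld]\n",
         static_cast<long long>(packed[0]),
         static_cast<long long>(packed[1]));  // [1024, 256]
  return 0;
}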

backends/vulkan/runtime/graph/ops/impl/Staging.h (+12 -0)

@@ -51,6 +51,18 @@ ValueRef prepack_standard(
     const bool passthrough = false,
     const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
 
+/*
+ * Same as prepack_standard, but transpose the height and width dimensions of
+ * the tensor while packing.
+ */
+ValueRef prepack_standard_hw_transposed(
+    ComputeGraph& graph,
+    const ValueRef tensor_data,
+    const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout layout,
+    const bool passthrough = false,
+    const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
+
 /*
  * Equivalent to `prepack_standard()` function, except the `storage_type` and
  * `memory_layout` are set to match `to_copy`, which must be a `Tensor`.

backends/vulkan/test/op_tests/cases.py (+3 -1)

@@ -157,12 +157,14 @@ def get_weight_int8pack_mm_inputs():
         [6, 1024, 256],
         [6, 256, 256],
         [6, 256, 512],
+        [4, 768, 4096],
+        [1024, 1024, 1024],
     ]
 
     inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list]
 
     test_suite = VkTestSuite(inputs_list)
-    test_suite.dtypes = ["at::kFloat", "at::kHalf"]
+    test_suite.dtypes = ["at::kFloat"]
     test_suite.layouts = ["utils::kWidthPacked"]
     test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
     test_suite.prepacked_args = ["mat2", "scales"]
