From 2eae7a963fbb62d9ebcfadd22071919e8b48b6db Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Mon, 23 Sep 2024 15:36:25 -0700
Subject: [PATCH] Move QMat2 to buffer storage and scales_and_zeros to Channels Packed (#5515)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5515

Storing QMat2 in a texture gives rise to two main problems:
- Indexing is a mess, and additional computation is required to account for the fact that we read ivec4s but only use half of the values
- There is no texel fetching in int8; the texel is read as int32 and needs to be cast

Keeping QMat2 in a buffer performs better: although reading from buffers is slower, removing the extra computation compensates for this.

{F1863459327}

This diff also moves the scales_and_zeros tensor to Channels Packed in the texture implementation because it just makes more sense; I had done some terrible indexing shenanigans before.

ghstack-source-id: 244258611
exported-using-ghexport

Reviewed By: yipjustin

Differential Revision: D62504978

fbshipit-source-id: df2fdf87f75140be0a316576c8ffad67feefd6d7
---
 .../runtime/graph/ops/glsl/q_4w_linear.glsl  | 47 +++++++------------
 .../graph/ops/impl/QuantizedMatMul.cpp       | 18 +++----
 .../vulkan/test/vulkan_compute_api_test.cpp  | 18 +++++--
 3 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
index d07d45251f..de42f9ed99 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
@@ -26,13 +26,14 @@ layout(std430) buffer;
 ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
 ${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
-${layout_declare_tensor(2, "r", "t_mat2", "int8", STORAGE)}
+${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")}
 ${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}

 $if STORAGE == "texture3d":
   ${layout_declare_ubo(4, "ivec4", "out_sizes")}
   ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(6, "ivec4", "scales_strides")}
+  ${layout_declare_ubo(6, "ivec4", "mat2_strides")}
+  ${layout_declare_ubo(7, "ivec4", "scales_strides")}
 $else:
   ${layout_declare_ubo(4, "ivec4", "out_sizes")}
   ${layout_declare_ubo(5, "ivec4", "out_strides")}
@@ -64,9 +65,9 @@ void main() {

   float rc = 0.0;
   int k = 0;
+  const uint k_block = (K + group_size - 1) / group_size;

 #ifdef USING_BUFFER
-  const uint k_block = (K + group_size - 1) / group_size;
   ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
   ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
   ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
@@ -101,42 +102,30 @@ void main() {
   t_out[out_bufi] = FLOAT_T(rc);

 #else // Using texture
-  const uint texel_group_size = group_size / FOUR;
-  const uint k_block = (K + texel_group_size - 1) / texel_group_size;
   ivec3 mat1_pos = ivec3(0, m, out_pos.z);
-  ivec3 mat2_pos = ivec3(0, n, out_pos.z);
-  ivec3 scale_pos = ivec3(0, n, 0);
-  ivec3 zero_pos = ivec3(0, n, 1);
+  ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
+  ivec3 scale_zero_pos = ivec3(0, n, 0);
+  uint K_texel = K / FOUR;

   for (int kb = 0; kb < k_block; kb++) {
-    const int texel_kb = kb / FOUR;
-    const int kb_offset = kb % FOUR;
-
-    scale_pos.x = texel_kb;
-    const VEC4_T scale_texel = load_texel(t_scales_and_zeros, scale_pos);
-    const float scale = float(scale_texel[kb_offset]);
+    scale_zero_pos.x = kb;
+    const vec4 scale_zero = load_texel(t_scales_and_zeros, 
scale_zero_pos); + const float scale = scale_zero.x; + const float zero = scale_zero.y - scale * 8.0; - zero_pos.x = texel_kb; - const VEC4_T zero_texel = load_texel(t_scales_and_zeros, zero_pos); - const float zero = float(zero_texel[kb_offset]) - scale * 8.0; - - for(uint idx = 0; idx < texel_group_size && k < K; idx++, k++) { + for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) { mat1_pos.x = k; const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); - mat2_pos.x = k / 2; - const i8vec4 mat2_tex = i8vec4(load_texel(t_mat2, mat2_pos)); + mat2_pos.x = k * 2; // k * FOUR / 2 + const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides); - // Every two texels of mat1 correspond to one texel of mat2 - // Even mat1 indeces correspond to first half of mat2 texel and - // odd indeces correspond to second half - const int mat2_offset = (k & 1) == 0 ? 0 : 2; - for (int texel_idx = 0; texel_idx < FOUR; texel_idx++){ + for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) { // Bitwise op treats sign bit from int8 as a value bit instead, // since there is no uint8_t datatype - uint mat2_val = (mat2_tex[mat2_offset + texel_idx / 2] & 0xFF); - mat2_val = (texel_idx & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); - rc += mat1_tex[texel_idx] * (scale * float(mat2_val) + zero); + uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF); + mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); + rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero); } } } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp index 0152a4a351..17bd62ad6e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp @@ -33,7 +33,13 @@ void check_q_matmul_args( using namespace WHCN; VK_CHECK_COND(graph.packed_dim_of(mat1) == kWidthDim); VK_CHECK_COND(graph.packed_dim_of(mat2_data) == kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); + // VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); + + if (graph.storage_type_of(scales_and_zeros) == utils::kBuffer) { + VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); + } else { + VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kChannelsDim); + } if (graph.storage_type_of(out) == utils::kBuffer) { VK_CHECK_COND(graph.packed_dim_of(out) == kWidthDim); @@ -106,13 +112,8 @@ void add_q_matmul_node( const ValueRef out) { auto storage_type = graph.storage_type_of(out); - ValueRef mat2; - - if (storage_type == utils::kBuffer) { - mat2 = prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); - } else { - mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); - } + ValueRef mat2 = + prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); ValueRef scales_and_zeros = prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); @@ -135,6 +136,7 @@ void add_q_matmul_node( } else { ubos.append(graph.sizes_ubo(out)); ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.strides_ubo(mat2)); ubos.append(graph.strides_ubo(scales_and_zeros)); } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 44d183a8a5..4b024cf969 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -2932,7 +2932,7 @@ void test_int4pack_mm( int4mm_pack_weights(mat2_size, B_quant_data.data()); 
IOValueRef B_int4 = - graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, storage_type); + graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer); graph.copy_into_staging( B_int4.staging, B_int4_data.data(), B_int4_data.size()); @@ -2940,8 +2940,18 @@ void test_int4pack_mm( // Random scales and zeroes. Keep scales small to avoid overflow and zeroes in // int4 range - IOValueRef scales_and_zeros = - graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, storage_type); + IOValueRef scales_and_zeros; + + if (storage_type == utils::kBuffer) { + scales_and_zeros.value = graph.add_tensor( + {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kWidthPacked); + } else { + scales_and_zeros.value = graph.add_tensor( + {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kChannelsPacked); + } + + scales_and_zeros.staging = graph.set_input_tensor(scales_and_zeros.value); + std::vector s_data(graph.numel_of(scales_and_zeros.value)); const int zeros_stride = s_data.size() / 2; for (size_t i = 0; i < zeros_stride; i++) { @@ -3003,7 +3013,7 @@ void test_int4pack_mm( out_deq.staging, out_deq_data.data(), out_deq_data.size()); for (int i = 0; i < out_int4_data.size(); i++) { - CHECK_VALUE(out_int4_data, i, out_deq_data[i]); + EXPECT_TRUE(check_close(out_int4_data[i], out_deq_data[i])); } }
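
Note (not part of the patch): for reference, the texture-path inner loop above boils down to unpacking two 4-bit weights from each int8 byte of the buffer-backed QMat2 and dequantizing them with a per-group scale and zero point. The following is a minimal CPU-side C++ sketch that mirrors that logic; the helper names (unpack_4bit_pair, dequantize) are hypothetical and only illustrate the shader's bitwise extraction and the folded -8 zero-point offset.

    #include <cstdint>
    #include <utility>

    // Each byte packs two unsigned 4-bit values: the low nibble first, then
    // the high nibble, matching `mat2_val & mask` vs `mat2_val >> 4` above.
    inline std::pair<uint32_t, uint32_t> unpack_4bit_pair(int8_t byte) {
      // Mask to 8 bits so the int8 sign bit is treated as a value bit,
      // as the shader comment describes.
      const uint32_t b = static_cast<uint32_t>(byte) & 0xFF;
      return {b & 0x0F, b >> 4};
    }

    // Dequantize a 4-bit weight with a per-group scale and zero point.
    // The shader folds the -8 offset into zero: zero = zeros - scale * 8.0.
    inline float dequantize(uint32_t w4, float scale, float zero) {
      return scale * static_cast<float>(w4) + zero;
    }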