From 0e7955d2263580136d85cc50ae69873ad3adcd07 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Tue, 14 May 2024 14:46:25 -0700
Subject: [PATCH] Implement `aten.linear.default` (#3594)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3594

Implements `aten.linear.default`, as the title says. The implementation is
fairly simple: the shaders only need to accumulate along the width dim of the
`mat2` tensor rather than along its height dim, since the weight of a linear
layer arrives already transposed.
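For intuition, here is a minimal Python sketch (illustrative only, not part of
the change; `linear_ref` is a made-up name) of the contraction the new
`linear_*` shader variants perform. The inner loop walks the weight `w` along
its width (last) dim, i.e. within row `n`, instead of down a column as the
plain matmul shaders do:

    def linear_ref(x, w, b=None):
        # x: [M][K], w: [N][K] (transposed weight), out: [M][N]
        M, K, N = len(x), len(x[0]), len(w)
        out = [[0.0] * N for _ in range(M)]
        for m in range(M):
            for n in range(N):
                acc = 0.0
                for k in range(K):
                    # accumulate across w's width dim (k moves within row n)
                    acc += x[m][k] * w[n][k]
                out[m][n] = acc + (b[n] if b is not None else 0.0)
        return out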
"indexing_utils.h" #include "matmul.h" @@ -31,11 +34,8 @@ layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes { ivec4 self_sizes; }; -layout(set = 0, binding = 7) uniform PRECISION restrict PackedDimMeta { - int packed_dim_size; - int packed_dim_size_padded; - int packed_dim_texel_len; - int packed_dim_padding; +layout(set = 0, binding = 7) uniform PRECISION restrict InLimits { + ivec3 in_limits; }; layout(set = 0, binding = 8) uniform PRECISION restrict Params { @@ -57,8 +57,7 @@ void main() { im_mat2, pos, out_sizes[2], - packed_dim_texel_len, - packed_dim_padding); + in_limits[0]); for (int idx_c = 0; idx_c < FOUR; idx_c++) { for (int idx_r = 0; idx_r < FOUR; idx_r++) { @@ -70,17 +69,16 @@ void main() { out_pos, self_sizes.x == 1, self_sizes.y == 1); - results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * results.data[idx_c][idx_r][0]; // results is in transposed order w.r.t. the desired output imageStore( im_out, out_pos, vec4( - results.data[idx_c][idx_r][0], - results.data[idx_c][idx_r][1], - results.data[idx_c][idx_r][2], - results.data[idx_c][idx_r][3])); + beta * self_texel.x + alpha * results.data[idx_c][idx_r][0], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][1], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][2], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][3])); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml index 53352342a8..73014d440d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml @@ -9,9 +9,12 @@ addmm_optimized: DTYPE: float NDIM: 3 PACKING: C_packed + MAT2_IS_TRANSPOSED: false generate_variant_forall: DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: addmm_optimized + - NAME: linear_optimized + MAT2_IS_TRANSPOSED: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul.h b/backends/vulkan/runtime/graph/ops/glsl/matmul.h index ec00a53a64..5a7f679587 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul.h +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul.h @@ -16,38 +16,66 @@ struct FloatMatrix { float data[FOUR][FOUR][FOUR]; }; +#ifdef MAT2_IS_TRANSPOSED +vec4 matmul_naive_W_packed_W_packed( +#else vec4 matmul_naive_W_packed_H_packed( - sampler3D im_mat1, - sampler3D im_mat2, - ivec3 mat1_pos, - ivec3 mat2_pos, +#endif + const sampler3D im_mat1, + const sampler3D im_mat2, + const ivec3 out_pos, const int width) { + ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z); +#ifdef MAT2_IS_TRANSPOSED + ivec3 mat2_pos = ivec3(0, out_pos.x * 4, 0); +#else + ivec3 mat2_pos = ivec3(out_pos.x * 4, 0, out_pos.z); +#endif + vec4 texel = vec4(0); - int K = (width + 3) / 4; + const int K = (width + 3) / 4; for (int i = 0; i < K; ++i) { - vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0); - vec4 sums = vec4( + const vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0); +#ifdef MAT2_IS_TRANSPOSED + const vec4 sums = vec4( + dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 1, 0), 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 2, 0), 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 3, 0), 0))); +#else + const vec4 sums = vec4( dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + 
         dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0)));
+#endif

     texel += sums;

     mat1_pos.x++;
+#ifdef MAT2_IS_TRANSPOSED
+    mat2_pos.x++;
+#else
     mat2_pos.y++;
+#endif
   }

   return texel;
 }

+#ifdef MAT2_IS_TRANSPOSED
+vec4 matmul_naive_W_packed_H_packed(
+#else
 vec4 matmul_naive_W_packed_W_packed(
-    sampler3D im_mat1,
-    sampler3D im_mat2,
-    ivec3 mat1_pos,
-    ivec3 mat2_pos,
+#endif
+    const sampler3D im_mat1,
+    const sampler3D im_mat2,
+    const ivec3 out_pos,
     const int width) {
+  ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
+  ivec3 mat2_pos = ivec3(out_pos.x, 0, out_pos.z);
+
   vec4 texel = vec4(0);
   int K = divup4(width);

@@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
   else if (broadcast_at_height) {
     self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
   } else {
-    self_texel = texelFetch(im_self, pos, 0);
+    self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
   }

   return self_texel;
@@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
   else if (broadcast_at_height) {
     self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
   } else {
-    self_texel = texelFetch(im_self, pos, 0);
+    self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
   }

   return self_texel;
@@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
     sampler3D im_mat2,
     const ivec3 pos,
     const int batch_size,
-    const int K_texel_len,
-    const int packed_dim_padding) {
+    const int K_texel_len) {
   FloatMatrix results;
   for (int i = 0; i < FOUR; i++) {
     for (int j = 0; j < FOUR; j++) {
@@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
       }
     }
   }
-  vec4 im_mat1_partial_rows[FOUR];
-  vec4 im_mat2_partial_cols[FOUR];
+  vec4 im_mat1_partial_load[FOUR];
+  vec4 im_mat2_partial_load[FOUR];

   for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
     if (FOUR * pos.z + batch_idx >= batch_size) {
       break;
     }
-    // read and cache 4x4 tile of im_mat1 (4 adjacent rows)
+    int mat_z = FOUR * pos.z + batch_idx;
     for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
-      for (int mat1_row = 0; mat1_row < FOUR; mat1_row++) {
-        const int mat1_y = (FOUR * pos.y) + mat1_row;
-        const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, FOUR * pos.z + batch_idx);
-        im_mat1_partial_rows[mat1_row] = texelFetch(im_mat1, mat1_pos, 0);
-        // set the value out of the boundary to be 0
-        if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
-          for (int kk = 0; kk < packed_dim_padding; kk++) {
-            im_mat1_partial_rows[mat1_row][3 - kk] = 0;
-          }
-        }
-      }
-      // read and cache 4x4 tile of im_mat2 (4 adjacent columns)
-      for (int mat2_col = 0; mat2_col < FOUR; mat2_col++) {
-        const int mat2_x = (FOUR * pos.x) + mat2_col;
-        const ivec3 pos_rd = ivec3(mat2_x, mat1_x, FOUR * pos.z + batch_idx);
-        im_mat2_partial_cols[mat2_col] = texelFetch(im_mat2, pos_rd, 0);
-        // set the value out of the boundary to be 0
-        if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
-          for (int kk = 0; kk < packed_dim_padding; kk++) {
-            im_mat2_partial_cols[mat2_col][3 - kk] = 0;
-          }
-        }
+      for (int offset = 0; offset < FOUR; offset++) {
+        // read and cache 4x4 tile of im_mat1
+        const int mat1_y = (FOUR * pos.y) + offset;
+        const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
+        im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
+        // read and cache 4x4 tile of im_mat2
+#ifdef MAT2_IS_TRANSPOSED
+        const int mat2_y = (FOUR * pos.x) + offset;
+        const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
+        im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
+#else
+        const int mat2_x = (FOUR * pos.x) + offset;
+        const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, mat_z);
+        im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
+#endif
       }
       // perform partial dot products and add partial result to results
       for (int out_row = 0; out_row < FOUR; out_row++) {
         for (int out_col = 0; out_col < FOUR; out_col++) {
           results.data[out_row][out_col][batch_idx] +=
-              dot(im_mat1_partial_rows[out_row], im_mat2_partial_cols[out_col]);
+              dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
         }
       }
     }

diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
index d7e4395d04..37a9b60f3c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -35,24 +38,19 @@ void main() {
   }

   vec4 texel = vec4(0);
-  ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

   $if MAT1_PACKING == "W_packed":
     $if MAT2_PACKING == "H_packed":
-      ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
       texel = matmul_naive_W_packed_H_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $elif MAT2_PACKING == "W_packed":
-      ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
       texel = matmul_naive_W_packed_W_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $else:
       $raise Exception("Unsupported value for MAT2_PACKING")

diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
index 727e8b361d..1c4db3f0ce 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
@@ -10,6 +10,7 @@ matmul_naive:
     NDIM: 3
     MAT1_PACKING: W_packed
     MAT2_PACKING: H_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
@@ -18,3 +19,6 @@ matmul_naive:
     - NAME: matmul_naive_W_packed_H_packed
     - NAME: matmul_naive_W_packed_W_packed
       MAT2_PACKING: W_packed
+    - NAME: matmul_transposed_naive_W_packed_W_packed
+      MAT2_PACKING: W_packed
+      MAT2_IS_TRANSPOSED: true

diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
index dd9c57416d..f39bea12be 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -25,11 +28,8 @@ layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes {
   ivec4 out_sizes;
 };

-layout(set = 0, binding = 5) uniform PRECISION restrict PackedDimMeta {
-  int packed_dim_size;
-  int packed_dim_size_padded;
-  int packed_dim_texel_len;
-  int packed_dim_padding;
+layout(set = 0, binding = 5) uniform PRECISION restrict InLimits {
+  ivec3 in_limits;
 };

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -46,8 +46,7 @@ void main() {
       im_mat2,
       pos,
       out_sizes[2],
-      packed_dim_texel_len,
-      packed_dim_padding);
+      in_limits[0]);

   for (int idx_c = 0; idx_c < FOUR; idx_c++) {
     for (int idx_r = 0; idx_r < FOUR; idx_r++) {

diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
index 7cec20e167..ecc62f7ca3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
@@ -9,9 +9,12 @@ matmul_optimized:
     DTYPE: float
     NDIM: 3
     PACKING: C_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
     - NAME: matmul_optimized
+    - NAME: matmul_transposed_optimized
+      MAT2_IS_TRANSPOSED: true

diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl
index 2429c841c9..6680baad03 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl
@@ -35,7 +35,7 @@ layout(constant_id = 4) const int out_packed_dim = C_DIM;

 void main() {
   const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
-  const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_sizes, out_packed_dim);
+  ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_sizes, out_packed_dim);

   if (all(greaterThanEqual(out_tensor_idx, out_sizes))) {
     return;
@@ -46,13 +46,15 @@ void main() {
   // the input position from the indx.
   const ivec4 buf_indices = get_texel_nchw_buffer_ixs(out_tensor_idx, out_sizes, out_packed_dim);

-  VEC4_T value;
+  VEC4_T value = VEC4_T(0);
   // Need to look up the 4 values in the output texel separately.
-  for (int i =0 ; i < 4; i++) {
-    ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes);
-    ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, in_packed_dim);
-    VEC4_T intex = texelFetch(image_in, in_pos_elem.xyz, 0);
-    value[i] = intex[in_pos_elem.w];
+  for (int i = 0 ; i < 4; i++) {
+    if (out_tensor_idx[out_packed_dim]++ < out_sizes[out_packed_dim]) {
+      ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes);
+      ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, in_packed_dim);
+      VEC4_T intex = texelFetch(image_in, in_pos_elem.xyz, 0);
+      value[i] = intex[in_pos_elem.w];
+    }
   }

   imageStore(image_out, out_pos, value);

diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 9e4ea7a9ba..8c963579da 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -8,6 +8,7 @@

 #include

+#include
 #include

@@ -56,21 +57,27 @@ void resize_addmm_node(
     ComputeGraph* graph,
     const std::vector& args,
     const std::vector& extra_args) {
-  (void)extra_args;
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
   vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
   vTensorPtr self = graph->get_tensor(args[1].refs[2]);

+  bool mat2_is_transposed = graph->get_bool(extra_args[0]);
+
+  const int out_cols = api::utils::val_at(-2, mat1->sizes());
+  const int out_rows = mat2_is_transposed
+      ? api::utils::val_at(-2, mat2->sizes())
+      : api::utils::val_at(-1, mat2->sizes());
+
   std::vector new_out_sizes(3);
   if (mat1->sizes().size() == 2) {
     new_out_sizes.resize(2);
-    new_out_sizes.at(0) = mat1->sizes().at(0);
-    new_out_sizes.at(1) = mat2->sizes().at(1);
+    new_out_sizes.at(0) = out_cols;
+    new_out_sizes.at(1) = out_rows;
   } else {
     new_out_sizes.at(0) = mat1->sizes().at(0);
-    new_out_sizes.at(1) = mat1->sizes().at(1);
-    new_out_sizes.at(2) = mat2->sizes().at(2);
+    new_out_sizes.at(1) = out_cols;
+    new_out_sizes.at(2) = out_rows;
   }

   out->virtual_resize(new_out_sizes);
@@ -83,19 +90,22 @@ struct Params final {

 void add_addmm_naive_node(
     ComputeGraph& graph,
-    const ValueRef self,
+    const ValueRef self_data,
     const ValueRef mat1,
     const ValueRef mat2_data,
     const ValueRef beta,
     const ValueRef alpha,
     const ValueRef out,
-    const Params& params) {
+    const Params& params,
+    const ValueRef mat2_is_transposed) {
+  ValueRef self = prepack_if_tensor_ref(graph, self_data, api::kWidthPacked);
   ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);

   api::utils::uvec3 global_size = graph.extents_of(out);
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

-  std::string kernel_name("addmm_naive");
+  std::string kernel_name =
+      graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
   kernel_name.reserve(kShaderNameReserve);
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1));
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2));
@@ -119,18 +129,21 @@ void add_addmm_naive_node(
       // Specialization Constants
       {},
       // Resizing Logic
-      resize_addmm_node));
+      resize_addmm_node,
+      {mat2_is_transposed}));
 }

 void add_addmm_optimized_node(
     ComputeGraph& graph,
-    const ValueRef self,
+    const ValueRef self_data,
     const ValueRef mat1,
     const ValueRef mat2_data,
     const ValueRef beta,
     const ValueRef alpha,
     const ValueRef out,
-    const Params& params) {
+    const Params& params,
+    const ValueRef mat2_is_transposed) {
+  ValueRef self = prepack_if_tensor_ref(graph, self_data, api::kChannelsPacked);
   ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);

   // Ensure mat1 is width packed
@@ -138,18 +151,24 @@ void add_addmm_optimized_node(
   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
   viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

+  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);
+
   // Ensure mat2 is height packed
-  ValueRef mat2_H_packed = mat2;
-  if (graph.memory_layout_of(mat2) != api::kHeightPacked) {
-    mat2_H_packed = graph.add_tensor_like(mat2, api::kHeightPacked);
-    viewFn(graph, {mat2, graph.add_none(), mat2_H_packed});
+  ValueRef mat2_packed = mat2;
+  const api::GPUMemoryLayout mat2_layout =
+      mat2_is_transposed_val ? api::kWidthPacked : api::kHeightPacked;
+  if (graph.memory_layout_of(mat2) != mat2_layout) {
+    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
+    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }

   api::utils::uvec3 global_size =
       api::utils::divup_vec(graph.extents_of(out), {4, 4, 1});
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

-  std::string kernel_name("addmm_optimized");
"linear_optimized" + : "addmm_optimized"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -159,19 +178,20 @@ void add_addmm_optimized_node( local_size, // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, - {{mat1_W_packed, mat2_H_packed, self}, api::MemoryAccessType::READ}}, + {{mat1_W_packed, mat2_packed, self}, api::MemoryAccessType::READ}}, // Shader params buffers { graph.texture_limits_ubo(out), graph.sizes_ubo(out), graph.sizes_ubo(self), - graph.packed_dim_meta_ubo(mat1_W_packed), + graph.texture_limits_ubo(mat1_W_packed), graph.create_params_buffer(params), }, // Specialization Constants {}, // Resizing Logic - resize_addmm_node)); + resize_addmm_node, + {mat2_is_transposed})); } void add_addmm_node( @@ -181,18 +201,25 @@ void add_addmm_node( const ValueRef mat2, const ValueRef beta, const ValueRef alpha, - const ValueRef out) { + const ValueRef out, + const ValueRef mat2_is_transposed) { float alpha_val = 1.0f; float beta_val = 1.0f; - alpha_val = graph.extract_scalar(alpha); - beta_val = graph.extract_scalar(beta); + if (alpha != kDummyValueRef) { + alpha_val = graph.extract_scalar(alpha); + } + if (beta != kDummyValueRef) { + beta_val = graph.extract_scalar(beta); + } Params params = {alpha_val, beta_val}; if (graph.memory_layout_of(mat1) == api::kChannelsPacked) { - add_addmm_optimized_node(graph, self, mat1, mat2, beta, alpha, out, params); + add_addmm_optimized_node( + graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else if (graph.memory_layout_of(mat1) == api::kWidthPacked) { - add_addmm_naive_node(graph, self, mat1, mat2, beta, alpha, out, params); + add_addmm_naive_node( + graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else { VK_THROW("Input should be channel packed or width packed."); } @@ -200,12 +227,44 @@ void add_addmm_node( void addmm(ComputeGraph& graph, const std::vector& args) { check_addmm_args(graph, args[0], args[1], args[2], args[3], args[4], args[5]); + ValueRef mat2_is_transposed = graph.add_scalar(false); return add_addmm_node( - graph, args[0], args[1], args[2], args[3], args[4], args[5]); + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + mat2_is_transposed); +} + +void linear(ComputeGraph& graph, const std::vector& args) { + ValueRef input = args.at(0); + ValueRef weight_data = args.at(1); + ValueRef bias = args.at(2); + ValueRef out = args.at(3); + ValueRef weight = + prepack_if_tensor_ref(graph, weight_data, api::kWidthPacked); + ValueRef mat2_is_transposed = graph.add_scalar(true); + if (graph.val_is_none(bias)) { + return add_matmul_node(graph, input, weight, out, mat2_is_transposed); + } else { + return add_addmm_node( + graph, + bias, + input, + weight, + kDummyValueRef, + kDummyValueRef, + out, + mat2_is_transposed); + } } REGISTER_OPERATORS { VK_REGISTER_OP(aten.addmm.default, addmm); + VK_REGISTER_OP(aten.linear.default, linear); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 063956ad31..0bdfad1c23 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -38,20 +39,26 @@ void resize_matmul_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - (void)extra_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr mat1 = 
   vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
   vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);

+  bool mat2_is_transposed = graph->get_bool(extra_args[0]);
+
+  const int out_cols = api::utils::val_at(-2, mat1->sizes());
+  const int out_rows = mat2_is_transposed
+      ? api::utils::val_at(-2, mat2->sizes())
+      : api::utils::val_at(-1, mat2->sizes());
+
   std::vector new_out_sizes(3);
   if (mat1->sizes().size() == 2) {
     new_out_sizes.resize(2);
-    new_out_sizes.at(0) = mat1->sizes().at(0);
-    new_out_sizes.at(1) = mat2->sizes().at(1);
+    new_out_sizes.at(0) = out_cols;
+    new_out_sizes.at(1) = out_rows;
   } else {
     new_out_sizes.at(0) = mat1->sizes().at(0);
-    new_out_sizes.at(1) = mat1->sizes().at(1);
-    new_out_sizes.at(2) = mat2->sizes().at(2);
+    new_out_sizes.at(1) = out_cols;
+    new_out_sizes.at(2) = out_rows;
   }

   out->virtual_resize(new_out_sizes);
@@ -61,13 +68,16 @@ void add_matmul_naive_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef mat2_data,
-    const ValueRef out) {
+    const ValueRef out,
+    const ValueRef mat2_is_transposed) {
   ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);

   api::utils::uvec3 global_size = graph.extents_of(out);
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

-  std::string kernel_name("matmul_naive");
+  std::string kernel_name = graph.get_bool(mat2_is_transposed)
+      ? "matmul_transposed_naive"
+      : "matmul_naive";
   kernel_name.reserve(kShaderNameReserve);
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1));
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2));
@@ -89,14 +99,16 @@ void add_matmul_naive_node(
       // Specialization Constants
       {},
       // Resizing Logic
-      resize_matmul_node));
+      resize_matmul_node,
+      {mat2_is_transposed}));
 }

 void add_matmul_optimized_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef mat2_data,
-    const ValueRef out) {
+    const ValueRef out,
+    const ValueRef mat2_is_transposed) {
   ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);

   // Ensure mat1 is width packed
@@ -104,18 +116,24 @@ void add_matmul_optimized_node(
   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
   viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

+  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);
+
   // Ensure mat2 to height packed
-  ValueRef mat2_H_packed = mat2;
-  if (graph.memory_layout_of(mat2) != api::kHeightPacked) {
-    mat2_H_packed = graph.add_tensor_like(mat2, api::kHeightPacked);
-    viewFn(graph, {mat2, graph.add_none(), mat2_H_packed});
+  ValueRef mat2_packed = mat2;
+  const api::GPUMemoryLayout mat2_layout =
+      mat2_is_transposed_val ? api::kWidthPacked : api::kHeightPacked;
+  if (graph.memory_layout_of(mat2) != mat2_layout) {
+    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
+    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }

   api::utils::uvec3 global_size =
       api::utils::divup_vec(graph.extents_of(out), {4, 4, 1});
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

-  std::string kernel_name("matmul_optimized");
"matmul_transposed_optimized" + : "matmul_optimized"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -125,26 +143,30 @@ void add_matmul_optimized_node( local_size, // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, - {{mat1_W_packed, mat2_H_packed}, api::MemoryAccessType::READ}}, + {{mat1_W_packed, mat2_packed}, api::MemoryAccessType::READ}}, // Shader params buffers { graph.texture_limits_ubo(out), graph.sizes_ubo(out), - graph.packed_dim_meta_ubo(mat1_W_packed), + graph.texture_limits_ubo(mat1_W_packed), }, // Specialization Constants - {})); + {}, + // Resizing Logic + resize_matmul_node, + {mat2_is_transposed})); } void add_matmul_node( ComputeGraph& graph, const ValueRef mat1, const ValueRef mat2_data, - const ValueRef out) { + const ValueRef out, + const ValueRef mat2_is_transposed) { if (graph.memory_layout_of(mat1) == api::kChannelsPacked) { - add_matmul_optimized_node(graph, mat1, mat2_data, out); + add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed); } else if (graph.memory_layout_of(mat1) == api::kWidthPacked) { - add_matmul_naive_node(graph, mat1, mat2_data, out); + add_matmul_naive_node(graph, mat1, mat2_data, out, mat2_is_transposed); } else { VK_THROW("Input should be channel packed or width packed."); } @@ -152,7 +174,8 @@ void add_matmul_node( void matmul(ComputeGraph& graph, const std::vector& args) { check_matmul_args(graph, args[0], args[1], args[2]); - return add_matmul_node(graph, args[0], args[1], args[2]); + const ValueRef mat2_is_transposed = graph.add_scalar(false); + return add_matmul_node(graph, args[0], args[1], args[2], mat2_is_transposed); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.h b/backends/vulkan/runtime/graph/ops/impl/MatMul.h new file mode 100644 index 0000000000..38f7907f1b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#pragma once
+
+#include
+
+namespace vkcompute {
+
+void add_matmul_node(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef mat2_data,
+    const ValueRef out,
+    const ValueRef mat2_is_transposed);
+
+} // namespace vkcompute

diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp
index ef23110d11..b3b4dedefd 100644
--- a/backends/vulkan/runtime/graph/ops/impl/View.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -14,7 +14,50 @@ namespace vkcompute {

-void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
+std::vector compute_out_sizes(
+    std::vector orig_sizes,
+    std::vector& view_sizes) {
+  std::vector out_sizes(view_sizes.begin(), view_sizes.end());
+  int64_t numel = 1;
+  int64_t transferred_numel = 1;
+
+  for (int i = 0; i < orig_sizes.size(); i++) {
+    numel *= orig_sizes.at(i);
+  }
+  for (int i = 0; i < view_sizes.size(); i++) {
+    if (view_sizes.at(i) > 0) {
+      transferred_numel *= view_sizes.at(i);
+    }
+  }
+  for (int i = 0; i < out_sizes.size(); i++) {
+    if (out_sizes.at(i) == -1) {
+      out_sizes.at(i) = numel / transferred_numel;
+    }
+  }
+  return out_sizes;
+}
+
+void resize_view_node(
+    ComputeGraph* graph,
+    const std::vector& args,
+    const std::vector& extra_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+  if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) {
+    out->virtual_resize(in->sizes());
+  } else {
+    IntListPtr view_sizes = graph->get_int_list(extra_args[0]);
+    std::vector out_sizes =
+        compute_out_sizes(in->sizes(), *view_sizes);
+    out->virtual_resize(out_sizes);
+  }
+}
+
+void add_view_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef sizes,
+    ValueRef out) {
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);

@@ -35,13 +78,14 @@ void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
       // Parameter Buffers
       {t_out->sizes_ubo(), t_in->sizes_ubo()},
       // Specialization Constants
-      {SV(t_in->gpu_memory_layout_int()), SV(t_out->gpu_memory_layout_int())}));
+      {SV(t_in->gpu_memory_layout_int()), SV(t_out->gpu_memory_layout_int())},
+      // Resizing Logic
+      resize_view_node,
+      {sizes}));
 }

 void view(ComputeGraph& graph, const std::vector& args) {
-  // Note: The second argument size_ref is not used here. Since the output
-  // tensor's size have been determined during compilation.
-  return add_view_node(graph, args[0], args[2]);
+  return add_view_node(graph, args[0], args[1], args[2]);
 }

 REGISTER_OPERATORS {

diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index a1e6227a22..d115f1897f 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -94,6 +94,26 @@ def get_addmm_inputs():
     return test_suite


+def get_linear_inputs():
+    MKN_list = [
+        (S2, M2, M1),
+        (L, L, M1),
+    ]
+
+    inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
+    inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list]
+    inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
+    inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list]
+
+    test_suite = VkTestSuite(inputs_list)
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite
+
+
 def get_pool2d_inputs():
     test_suite = VkTestSuite(
         [
@@ -747,6 +767,7 @@ def get_gelu_inputs():
     "aten.addmm.default": get_addmm_inputs(),
     "aten.bmm.default": get_bmm_inputs(),
     "aten.mm.default": get_mm_inputs(),
+    "aten.linear.default": get_linear_inputs(),
     "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
     "aten.convolution.default": get_conv_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
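Editor's note: a small hedged sketch (not part of the patch) of the shape
patterns the new `get_linear_inputs()` cases exercise, checked against eager
PyTorch. The concrete sizes below are placeholders, not the symbolic S2, M2,
M1, L sizes that cases.py defines elsewhere:

    import torch
    import torch.nn.functional as F

    # Placeholder (M, K, N) triples standing in for the symbolic test sizes.
    for M, K, N in [(8, 16, 32), (4, 4, 4)]:
        x2d, x3d = torch.randn(M, K), torch.randn(3, M, K)
        w, b = torch.randn(N, K), torch.randn(N)
        assert F.linear(x2d, w).shape == (M, N)        # weight-only case
        assert F.linear(x2d, w, b).shape == (M, N)     # with bias
        assert F.linear(x3d, w, b).shape == (3, M, N)  # batched input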