Skip to content

[ET-VK][int4] patch 4-bit linear op for ensuring w-packed in/out #8225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,6 @@ void check_q_4w_linear_args(
const int group_size_val = graph.extract_scalar<int>(group_size);
VK_CHECK_COND(K % group_size_val == 0);

VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);

VK_CHECK_COND(graph.has_standard_axis_map(mat1));
VK_CHECK_COND(graph.has_standard_axis_map(out));
}
Expand Down Expand Up @@ -320,13 +317,32 @@ void add_q_4w_linear_node(

const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
// Create temporary tensors to store the width packed versions of mat1 and out
TmpTensor mat1_tmp(
&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
TmpTensor out_tmp(
&graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
if (storage_type == utils::kTexture3D) {
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
mat1_W_packed = mat1_tmp;
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
out_W_packed = out_tmp;
}
}

vkapi::ParamsBindList ubos({});
ubos.append(graph.logical_limits_ubo(out));
ubos.append(graph.sizes_ubo(mat1));
ubos.append(graph.logical_limits_ubo(out_W_packed));
ubos.append(graph.sizes_ubo(mat1_W_packed));
ubos.append(graph.strides_ubo(mat2));
ubos.append(graph.strides_ubo(scales_and_zeros));

utils::uvec3 global_wg_size = graph.logical_limits_of(out);
utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed);
utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

graph.execute_nodes().emplace_back(new DispatchNode(
Expand All @@ -335,15 +351,20 @@ void add_q_4w_linear_node(
global_wg_size,
local_wg_size,
// Inputs and Outputs
{{out, vkapi::MemoryAccessType::WRITE},
{{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
{{mat1_W_packed, mat2, scales_and_zeros},
vkapi::MemoryAccessType::READ}},
// Shader params buffers
ubos,
// Specialization Constants
{SV(group_size_val)},
// Resizing Logic
resize_q_4w_linear_node,
{}));
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(out) != WHCN::kWidthDim) {
viewFn(graph, {out_W_packed, graph.add_none(), out});
}
}

void linear_weight_int4(
Expand Down
Loading