
[ET-VK][int4] Wrap int4 linear calls with view_copy nodes to squeeze/unsqueeze inputs #8254


Merged — 3 commits, Feb 6, 2025
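In effect, int4 linear calls whose input carries leading size-1 dims are wrapped in a pair of view_copy nodes: one squeezes the input down to 2D before the custom op, one restores the original output shape after it. A minimal sketch of the rewritten pattern, with plain `reshape` standing in for the edge-dialect `view_copy` op and a matmul standing in for `et_vk.linear_weight_int4` (shapes illustrative):

```python
import torch

S, K, N = 7, 2048, 256
x = torch.randn(1, 1, S, K)

x2d = x.reshape(S, K)          # view_copy: squeeze away leading size-1 dims
y2d = x2d @ torch.randn(K, N)  # stand-in for et_vk.linear_weight_int4
y = y2d.reshape(1, 1, S, N)    # view_copy: restore the original output shape
```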
18 changes: 17 additions & 1 deletion backends/vulkan/_passes/TARGETS
@@ -30,6 +30,21 @@ runtime.python_library(
    ]
)

runtime.python_library(
    name = "squeeze_int4_linear_inputs",
    srcs = [
        "squeeze_int4_linear_inputs.py",
    ],
    visibility = [
        "//executorch/backends/...",
    ],
    deps = [
        "//executorch/backends/vulkan:custom_ops_lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ]
)

runtime.python_library(
    name = "remove_asserts",
    srcs = ["remove_asserts.py"],
@@ -99,6 +114,7 @@ runtime.python_library(
":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":tag_memory_meta_pass"
":squeeze_int4_linear_inputs",
":tag_memory_meta_pass",
]
)
12 changes: 12 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -1,3 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
@@ -12,6 +20,9 @@
from executorch.backends.vulkan._passes.remove_redundant_ops import (
    RemoveRedundantOpsTransform,
)
from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
    SqueezeInt4LinearInputs,
)
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
@@ -21,5 +32,6 @@
"RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"SqueezeInt4LinearInputs",
"TagMemoryMetaPass",
]
24 changes: 18 additions & 6 deletions backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -39,11 +39,12 @@ def __init__(
        from torchao.utils import find_multiple

        self.origin_in_features = in_features
        in_features = find_multiple(in_features, (1024,))
        # pyre-ignore[6]: Incompatible parameter type
        in_features = find_multiple(in_features, 1024)

        self.use_bias = bias
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
@@ -80,20 +81,28 @@ def __init__(
                device=device,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.empty((out_features,), dtype=torch.float32, device=device),
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
        # operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
        r = torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )
        if self.use_bias:
            return r + self.bias
        return r


# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    bias=child.bias is not None,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@
                if copy_weights and child.weight.device != torch.device("meta"):
                    # pyre-fixme[16]: `Module` has no attribute `weight`.
                    new_linear.weight = child.weight
                    if child.bias is not None:
                        # pyre-fixme[16]: `Module` has no attribute `bias`.
                        new_linear.bias = child.bias
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@
                    logging.warn(
                        f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                    )
                    padded_in_features = find_multiple(in_features, (1024,))
                    # pyre-ignore[6]: Incompatible parameter type
                    padded_in_features = find_multiple(in_features, 1024)
                    weight = F.pad(
                        weight, pad=(0, padded_in_features - in_features)
                    )
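For context on the `find_multiple` call sites changed above: `in_features` is rounded up to the next multiple of 1024 and the input is zero-padded to match. A minimal sketch of that arithmetic, with `find_multiple` re-implemented locally to mirror `torchao.utils.find_multiple` (assumption: same rounding behavior):

```python
import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k (mirrors torchao.utils.find_multiple).
    return n if n % k == 0 else n + k - (n % k)

in_features = 1000
padded = find_multiple(in_features, 1024)    # -> 1024
x = torch.randn(4, in_features)
# Pad the last dim so the kernel sees in_features % 1024 == 0.
x = F.pad(x, pad=(0, padded - in_features))  # shape: (4, 1024)
```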
64 changes: 64 additions & 0 deletions backends/vulkan/_passes/squeeze_int4_linear_inputs.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Dict, List, Tuple

import executorch.backends.vulkan.custom_ops_lib # noqa: needed to access vk op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue

from torch.fx.node import Argument


class SqueezeInt4LinearInputs(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        def _squeezable(shape: List[int]) -> bool:
            return len(shape) > 2 and 1 in shape

        if op != exir_ops.edge.et_vk.linear_weight_int4.default:
            return super().call_operator(op, args, kwargs, meta)

        # pyre-ignore[16]: `None` has no attribute `node`
        input_shape = args[0].node.meta["val"].shape
        output_shape = meta["val"].shape
        if not _squeezable(input_shape):
            return super().call_operator(op, args, kwargs, meta)

        # Squeeze the input tensor by dropping size-1 dims until it is 2D
        squeeze_shape = list(input_shape)
        while _squeezable(squeeze_shape):
            squeeze_shape.remove(1)

        squeeze_out = super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (args[0], squeeze_shape),
            kwargs,
            meta,
        )
        # Call the linear op on the squeezed input
        new_args = (squeeze_out, *args[1:])
        linear_out = super().call_operator(
            op,
            new_args,
            kwargs,
            meta,
        )
        # Unsqueeze the output back to the original output shape
        unsqueeze_shape = list(output_shape)
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (linear_out, unsqueeze_shape),
            kwargs,
            meta,
        )
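A worked trace of the squeeze loop above, run as plain Python on a hypothetical `[1, 1, 7, 2048]` input shape:

```python
def _squeezable(shape) -> bool:
    return len(shape) > 2 and 1 in shape

squeeze_shape = [1, 1, 7, 2048]
while _squeezable(squeeze_shape):
    squeeze_shape.remove(1)  # drops the first size-1 dim on each iteration

print(squeeze_shape)  # [7, 2048] -- the 2D shape handed to linear_weight_int4
```

The matching `view_copy` at the end restores the op's original output shape, so the rewrite is shape-transparent to the rest of the graph.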
37 changes: 29 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
  const int group_size_val = graph.extract_scalar<int>(group_size);
  VK_CHECK_COND(K % group_size_val == 0);

  VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);

  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
  VK_CHECK_COND(graph.has_standard_axis_map(out));
}
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(

  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

  ValueRef mat1_W_packed = mat1;
  ValueRef out_W_packed = out;
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  // Create temporary tensors to store the width packed versions of mat1 and out
  TmpTensor mat1_tmp(
      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
  TmpTensor out_tmp(
      &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
  if (storage_type == utils::kTexture3D) {
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
    // Ensure mat1 is width packed
    mat1_W_packed = mat1_tmp;
    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
    // Ensure out is packed correctly
    out_W_packed = out_tmp;
  }
  }

  vkapi::ParamsBindList ubos({});
  ubos.append(graph.logical_limits_ubo(out));
  ubos.append(graph.sizes_ubo(mat1));
  ubos.append(graph.logical_limits_ubo(out_W_packed));
  ubos.append(graph.sizes_ubo(mat1_W_packed));
  ubos.append(graph.strides_ubo(mat2));
  ubos.append(graph.strides_ubo(scales_and_zeros));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed);
  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
@@ -335,15 +351,20 @@
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, mat2, scales_and_zeros},
        vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {SV(group_size_val)},
      // Resizing Logic
      resize_q_4w_linear_node,
      {}));
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(out) != WHCN::kWidthDim) {
    viewFn(graph, {out_W_packed, graph.add_none(), out});
  }
}

void linear_weight_int4(
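For intuition on the width-packing requirement this code works around: texture storage groups four consecutive elements of the packed dim into one texel, so a tensor packed along a different dim must be repacked (here via `view_copy` into a `TmpTensor`) before the shader runs, and repacked back afterwards. A toy model of width packing, not the runtime's actual layout code:

```python
import torch
import torch.nn.functional as F

def width_pack(t: torch.Tensor) -> torch.Tensor:
    # Group elements along the last (width) dim into texels of 4,
    # zero-padding the row length up to a multiple of 4.
    h, w = t.shape
    w4 = (w + 3) // 4 * 4
    return F.pad(t, (0, w4 - w)).reshape(h, w4 // 4, 4)

x = torch.arange(12.0).reshape(2, 6)
print(width_pack(x).shape)  # torch.Size([2, 2, 4]) -- 2 texels per row
```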
1 change: 1 addition & 0 deletions backends/vulkan/targets.bzl
@@ -328,6 +328,7 @@ def define_common_targets(is_fbcode = False):
"//executorch/backends/transforms:fuse_dequant_linear",
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/backends/vulkan/serialization:lib",
"//executorch/exir/backend:backend_details",
6 changes: 6 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
@@ -19,10 +19,14 @@
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
    ViewCopyToSqueezeUnsqueezePass,
)
from executorch.backends.vulkan._passes import (
    insert_prepack_nodes,
    RemoveLocalScalarDenseOpsTransform,
    RemoveRedundantOpsTransform,
    SqueezeInt4LinearInputs,
    TagMemoryMetaPass,
)

@@ -149,7 +153,9 @@ def preprocess(  # noqa: C901
                RemoveRedundantOpsTransform(),
                AddmmToLinearTransform(),
                FuseDequantLinearPass(),
                SqueezeInt4LinearInputs(),
                FuseViewCopyTransform(),
                ViewCopyToSqueezeUnsqueezePass(),
                FuseBatchNormWithConvPass(program),
                FuseClampPass(),
            ],
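On ordering (my reading of the pass list, not stated in the diff): `SqueezeInt4LinearInputs` runs before `FuseViewCopyTransform` and `ViewCopyToSqueezeUnsqueezePass`, so the freshly inserted `view_copy` pairs can be fused with neighboring views or, judging by the pass name, lowered to squeeze/unsqueeze ops. The underlying equivalence those rewrites rely on:

```python
import torch

x = torch.randn(1, 7, 2048)
# A view_copy that only drops size-1 dims is a squeeze...
assert torch.equal(x.reshape(7, 2048), x.squeeze(0))
# ...and the inverse view_copy is an unsqueeze.
y = torch.randn(7, 256)
assert torch.equal(y.reshape(1, 7, 256), y.unsqueeze(0))
```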