-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
Use w8a8 quantized matmul Pallas kernel #19170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
from typing import Optional | ||
|
||
import torch | ||
# Required to register custom ops. | ||
import torch_xla.experimental.custom_kernel # noqa: F401 | ||
from functorch.experimental.control_flow import cond # noqa: F401 | ||
|
||
from vllm.model_executor.layers.quantization.utils import replace_parameter | ||
|
@@ -90,16 +92,24 @@ def apply_weights(self, | |
bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
w_q, w_s, _, _, _ = self._get_weight_params(layer) | ||
|
||
import torch_xla.experimental.xla_quantized_matmul # noqa: F401 | ||
out = torch.ops.xla.quantized_matmul(x, | ||
w_q, | ||
w_s, | ||
zero_point=None, | ||
block_size=-1, | ||
int4_weight=False, | ||
quantize_activation=True) | ||
# `quantized_matmul` output is fp32, cast it down to bf16 for perf | ||
out = out.to(x.dtype) | ||
# import torch_xla.experimental.xla_quantized_matmul # noqa: F401 | ||
# out = torch.ops.xla.quantized_matmul(x, | ||
# w_q, | ||
# w_s, | ||
# zero_point=None, | ||
# block_size=-1, | ||
# int4_weight=False, | ||
# quantize_activation=True) | ||
# # `quantized_matmul` output is fp32, cast it down to bf16 for perf | ||
# out = out.to(x.dtype) | ||
|
||
out = torch.ops.xla.quantized_matmul_int8( | ||
x, | ||
w_q, | ||
w_s, | ||
quantize_activation=True, | ||
) | ||
Comment on lines
+106
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous implementation using What is the output data type of the new |
||
|
||
# Explicitly capture control flow to make dynamo happy. | ||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 | ||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion
assert "1024" in output or "0, 1" in output
is the same as in thetest_basic
function. For a test specifically targetingw8a8
quantization, could this assertion be made more specific to validate the correctness of the quantization itself? For example, comparing against known good outputs for this quantized model or checking for specific numerical properties might provide stronger validation.