
Use w8a8 quantized matmul Pallas kernel #19170

Status: Draft. Wants to merge 2 commits into base: main.
32 changes: 32 additions & 0 deletions tests/v1/tpu/test_basic.py
@@ -67,6 +67,38 @@ def test_basic(
assert "1024" in output or "0, 1" in output


@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a basic test for TPU only")
def test_w8a8_quantization(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
) -> None:
model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
max_tokens = 5
tensor_parallel_size = 1
max_num_seqs = 4

prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

with vllm_runner(
model,
max_num_batched_tokens=64,
max_model_len=4096,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
output = vllm_outputs[0][1]

assert "1024" in output or "0, 1" in output
Review comment (Contributor, severity: medium):

The assertion assert "1024" in output or "0, 1" in output is the same as in the test_basic function. For a test specifically targeting w8a8 quantization, could this assertion be made more specific to validate the correctness of the quantization itself? For example, comparing against known-good outputs for this quantized model or checking for specific numerical properties might provide stronger validation.
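
As a sketch of what a stricter check could look like, the helper below pins the expected greedy continuation for the fixed counting prompt. The expected prefix is a placeholder assumption, not a verified output of this checkpoint; it would need to be captured from a known-good run before being used in CI.

```python
def assert_w8a8_completion(output: str) -> None:
    """Stricter check sketched from the review suggestion.

    EXPECTED_PREFIX is a placeholder; capture it from a known-good run of
    the quantized checkpoint before relying on it in CI.
    """
    EXPECTED_PREFIX = "1024, 1025"  # assumed continuation, verify on TPU
    assert output.startswith(EXPECTED_PREFIX), (
        f"unexpected w8a8 completion: {output!r}")
```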



TP_SIZE_8 = 8


30 changes: 20 additions & 10 deletions vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@@ -5,6 +5,8 @@
from typing import Optional

import torch
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
from functorch.experimental.control_flow import cond # noqa: F401

from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -90,16 +92,24 @@ def apply_weights(self,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_weight_params(layer)

        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        out = torch.ops.xla.quantized_matmul(x,
                                             w_q,
                                             w_s,
                                             zero_point=None,
                                             block_size=-1,
                                             int4_weight=False,
                                             quantize_activation=True)
        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
        out = out.to(x.dtype)
        # import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        # out = torch.ops.xla.quantized_matmul(x,
        #                                      w_q,
        #                                      w_s,
        #                                      zero_point=None,
        #                                      block_size=-1,
        #                                      int4_weight=False,
        #                                      quantize_activation=True)
        # # `quantized_matmul` output is fp32, cast it down to bf16 for perf
        # out = out.to(x.dtype)

        out = torch.ops.xla.quantized_matmul_int8(
            x,
            w_q,
            w_s,
            quantize_activation=True,
        )
Review comment on lines +106 to +111 (Contributor, severity: high):

The previous implementation using torch.ops.xla.quantized_matmul included a cast out = out.to(x.dtype) with the comment "quantized_matmul output is fp32, cast it down to bf16 for perf".

What is the output data type of the new torch.ops.xla.quantized_matmul_int8? If it also outputs in a higher precision (e.g., fp32) and x.dtype is a lower-precision format like bfloat16, is a similar cast to x.dtype still necessary for performance or for type consistency with subsequent layers? If the new op already handles this or outputs directly in x.dtype, then this change is fine.
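
One low-risk way to address this defensively, sketched here with the tensors already in scope in apply_weights and assuming (not asserting) that quantized_matmul_int8 may return fp32, is to cast only when the dtypes actually differ:

```python
# Defensive variant: the cast is a no-op if the kernel already returns
# x.dtype, and restores the old bf16 behavior if it returns fp32.
out = torch.ops.xla.quantized_matmul_int8(
    x,
    w_q,
    w_s,
    quantize_activation=True,
)
if out.dtype != x.dtype:
    out = out.to(x.dtype)  # keep activations in the model dtype downstream
```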


        # Explicitly capture control flow to make dynamo happy.
        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
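
For context on what the kernel computes, here is a minimal eager-mode sketch of a w8a8 matmul, assuming int8 weights with per-output-channel scales and dynamic per-token activation quantization. This is illustrative reference math only, not the Pallas kernel and not the exact contract of quantized_matmul_int8:

```python
import torch


def w8a8_matmul_reference(x: torch.Tensor, w_q: torch.Tensor,
                          w_s: torch.Tensor) -> torch.Tensor:
    # x:   [tokens, in_features] activations (e.g. bf16)
    # w_q: [out_features, in_features] int8 weights (assumed layout)
    # w_s: [out_features] per-output-channel weight scales
    # Dynamically quantize activations to int8, one scale per token (row).
    x_s = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    x_q = torch.clamp(torch.round(x / x_s), -128, 127).to(torch.int8)
    # Accumulate the integer matmul in int32, then rescale to real values.
    acc = torch.matmul(x_q.to(torch.int32), w_q.to(torch.int32).t())
    return (acc.to(torch.float32) * x_s * w_s).to(x.dtype)
```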
3 changes: 3 additions & 0 deletions vllm/v1/worker/tpu_worker.py
@@ -99,6 +99,9 @@ def init_device(self):
        # ring, the xla tpu compiler flag
        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
        # fix this. It will be removed after the bug in XLA compiler is fixed.
        # os.environ["LIBTPU_INIT_ARGS"] = (
        #     "--xla_tpu_force_1d_allreduce_at_chunk_count=1
        #     --xla_jf_conv_input_fusion=False")
        os.environ["LIBTPU_INIT_ARGS"] = (
            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
        torch.set_grad_enabled(False)