[Kernel] Implement fallback for FP8 channelwise using torch._scaled_mm #6552

Merged
@@ -23,16 +23,6 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
-        # On Lovelace, fail for now if channelwise.
-        # TODO: (@tms) fallback
-        if (not self.cutlass_fp8_supported
-                and self.strategy == QuantizationStrategy.CHANNEL):
-            raise ValueError(
-                "Channelwise fp8 quantization requires vLLM's custom "
-                "cutlass kernels, which are not supported on your device."
-                "Consider quantizing with per tensor scales or upgrading "
-                "to Hopper.")
-
     def get_min_capability(self) -> int:
         # lovelace and up
         return 89
@@ -53,7 +43,6 @@ def process_weights_after_loading(self, layer) -> None:
 
         # If channelwise, scales are already lined up, so just transpose.
        elif self.strategy == QuantizationStrategy.CHANNEL:
-            assert self.cutlass_fp8_supported
             weight = layer.weight
             layer.weight = Parameter(weight.t(), requires_grad=False)
 
vllm/model_executor/layers/quantization/utils/w8a8_utils.py (24 additions, 10 deletions)
@@ -124,20 +124,34 @@ def apply_fp8_linear(
                                        bias=bias)
 
     else:
+        # Note: we pad the input because torch._scaled_mm is more performant
+        # for matrices with batch dimension > 16.
+        # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(input,
                                                input_scale,
                                                batch_dim_padding=17)
 
-        # Fused GEMM_DQ -- note we padded the input above because
-        # torch._scaled_mm is more performant for matrices with
-        # batch dimension > 16. Note that this could change
-        # in the future.
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        if weight_scale.numel() == 1:
+            # Fused GEMM_DQ
+            output, _ = torch._scaled_mm(qinput,
+                                         weight,
+                                         out_dtype=input.dtype,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         bias=bias)
+        else:
+            # Fallback for channelwise case, where the weight scales are
+            # applied separately.
+            # Write output in fp32 to allow subsequent ops to happen in-place
Review comment (Collaborator), on the line above:

Could you just add a comment here that describes the math and why what we are doing works?
+            output, _ = torch._scaled_mm(qinput,
+                                         weight,
+                                         out_dtype=torch.float32,
+                                         scale_a=x_scale)
+
+            output = output * weight_scale.t()
+            if bias is not None:
+                output = output + bias
+            output = output.to(dtype=input.dtype)
 
     return torch.narrow(output, 0, 0, input.shape[0])
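
The reviewer's question above concerns why this fallback is correct. The identity the diff relies on, though it does not spell it out, is that a per-output-channel weight scale can be factored out of the GEMM and applied to the output columns afterwards: (X_q * s_x) @ (W_q * s_w.T) == ((X_q @ W_q) * s_x) * s_w.T, where s_x is a scalar activation scale and s_w has one entry per output channel. The standalone sketch below is illustrative only, not vLLM code; the tensor names and sizes (M, K, N, x_q, w_q, x_scale, w_scale) are made up, and plain float tensors stand in for the quantized values. It also checks that padding the batch dimension and trimming with torch.narrow, as the real code does for torch._scaled_mm performance, leaves the first M output rows unchanged.

# Standalone sketch (not from the PR): check that per-channel weight scales
# can be applied after the unscaled GEMM instead of before it.
import torch

torch.manual_seed(0)
M, K, N = 8, 64, 32

x_q = torch.randint(-8, 8, (M, K)).float()   # stand-in for quantized activations
w_q = torch.randint(-8, 8, (K, N)).float()   # stand-in for quantized weights (K x N)
x_scale = torch.tensor(0.05)                 # per-tensor activation scale
w_scale = torch.rand(N, 1) * 0.1             # per-output-channel weight scales, (N, 1)

# Reference: dequantize first, then run the GEMM.
ref = (x_q * x_scale) @ (w_q * w_scale.t())

# Fallback order: GEMM on quantized values with only the activation scale,
# then apply the channel scales to the output columns.
out = (x_q @ w_q) * x_scale
out = out * w_scale.t()                      # broadcasts over the N output columns

assert torch.allclose(ref, out, atol=1e-5)

# Padding the batch dim with zero rows (the real code pads to at least 17 rows
# because torch._scaled_mm is faster above 16) and trimming afterwards with
# torch.narrow does not change the first M output rows.
x_pad = torch.cat([x_q, torch.zeros(17 - M, K)])
out_pad = (x_pad @ w_q) * x_scale * w_scale.t()
assert torch.allclose(torch.narrow(out_pad, 0, 0, M), out, atol=1e-5)

print("channelwise scale factoring holds")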
