
Commit 2662be1

refactor: move FP8 quantization functions into QuantFP8
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
1 parent: b50d163

File tree: 3 files changed (+63 −129 lines)

vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 13 additions & 13 deletions
@@ -5,8 +5,9 @@
 
 import torch
 
-from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import (
-    quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token)
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -121,19 +122,18 @@ def _fp8_quantize(
     is provided, the output will be blocked.
     """
     if block_shape is None:
-        if per_act_token:
-            return quantize_fp8_per_token(A, A_scale)
-        else:
-            return quantize_fp8_per_tensor(A, A_scale)
+        # TODO(luka): use QuantFP8 custom op
+        # https://github.com/vllm-project/vllm/issues/20711
+        A, A_scale = ops.scaled_fp8_quant(
+            A, A_scale, use_per_token_if_dynamic=per_act_token)
     else:
-        assert not per_act_token, \
-            "per_act_token not supported with block_shape"
-        assert A_scale is None, \
-            "Group quantization doesn't support static scales"
-        assert len(block_shape) == 2, "block_shape must be [m, k]"
+        assert not per_act_token
+        assert len(block_shape) == 2
         _, block_k = block_shape[0], block_shape[1]
-        return quantize_fp8_per_group(
-            A, block_k, column_major_scales=False)  # Use row-major for MoE
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
+
+    return A, A_scale
 
 
 def _int8_quantize(
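
Note: with this change the non-blocked branch of _fp8_quantize defers to ops.scaled_fp8_quant, and the blocked branch to per_token_group_quant_fp8. As a rough illustration only, the plain-PyTorch sketch below mirrors what the blocked path computes and why the cdiv assertion on the scale width holds; it is not the Triton kernel, and the dtype and minimum-scale constants are assumptions matching the module constants added in input_quant_fp8.py below.

import torch

# Assumed constants for an e4m3fn (non-fnuz) platform.
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max          # 448.0 for e4m3fn
MIN_SCALE = 1.0 / (FP8_MAX * 512.0)


def ref_per_token_group_fp8(A: torch.Tensor, block_k: int):
    # One scale per contiguous block of block_k columns, row-major scales.
    M, K = A.shape
    assert K % block_k == 0, "sketch assumes K divisible by block_k"
    g = A.float().view(M, K // block_k, block_k)
    scale = (g.abs().amax(dim=-1, keepdim=True) / FP8_MAX).clamp(min=MIN_SCALE)
    q = (g / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE).view(M, K)
    return q, scale.squeeze(-1)               # scales: (M, K // block_k)


A = torch.randn(4, 256)
q, s = ref_per_token_group_fp8(A, block_k=128)
# Matches the assert cdiv(A.size(-1), block_k) == A_scale.size(-1) above.
assert s.size(-1) == (A.size(-1) + 128 - 1) // 128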

vllm/model_executor/layers/quantization/input_quant_fp8.py

Lines changed: 50 additions & 6 deletions
@@ -7,15 +7,17 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import (
-    quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
 
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
 _FP8_DTYPE = current_platform.fp8_dtype()
+_FP8_FINFO = torch.finfo(_FP8_DTYPE)
+_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
+_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
+_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
 
 
 @CustomOp.register("quant_fp8")
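
The four module-level constants added above pin the quantization range at import time: on non-fnuz platforms they come straight from torch.finfo, while ROCm fnuz hardware is capped at ±224.0 per the existing comment. A quick standalone way to inspect the underlying ranges (assuming a PyTorch build that exposes the FP8 dtypes):

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    fi = torch.finfo(dtype)
    # e4m3fn tops out at 448.0; e4m3fnuz at 240.0, which the module caps to 224.0.
    print(dtype, "min:", fi.min, "max:", fi.max)

# The minimum scaling factor keeps scales from collapsing to zero on all-zero inputs.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
print("min scaling factor:", 1.0 / (fp8_max * 512.0))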
@@ -92,9 +94,25 @@ def forward_native(
                                     and scale_ub.numel() == 1)
 
         if self.use_per_token_if_dynamic and scale is None:
-            out, scale = quantize_fp8_per_token(x, scale, scale_ub)
+            # Per-token quantization logic
+            x_max, _ = x.abs().max(dim=-1)
+            x_max = x_max.unsqueeze(-1).to(torch.float32)
+            if scale_ub is not None:
+                x_max = x_max.clamp(max=scale_ub)
+            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
         else:
-            out, scale = quantize_fp8_per_tensor(x, scale)
+            # Per-tensor quantization logic
+            if scale is None:
+                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
+                scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            # Even for dynamic per-token scales,
+            # reciprocal performs slightly better than division
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
 
         # This currently generates an extra Triton kernel in compilation.
         # Fortunately, we don't use padding if compiling.
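
The per-token and per-tensor paths were previously delegated to fp8_quant_ops and are now inlined into forward_native. The observable contract difference between the two branches is the scale shape: one scale per token row versus a single shared scalar. A minimal standalone sketch of that contract, reusing the same formulas with assumed e4m3fn constants (in the class these come from the module-level _FP8_* constants):

import torch

_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)

x = torch.randn(3, 8)

# Dynamic per-token: one scale per row.
x_max = x.abs().max(dim=-1)[0].unsqueeze(-1).to(torch.float32)
token_scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
token_out = (x.to(torch.float32) * token_scale.reciprocal()).clamp(
    _FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
assert token_scale.shape == (3, 1)

# Dynamic per-tensor: a single scalar scale shared by every element.
tensor_scale = (x.abs().max().unsqueeze(-1).to(torch.float32) /
                _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
tensor_out = (x.to(torch.float32) * tensor_scale.reciprocal()).clamp(
    _FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
assert tensor_scale.shape == (1,)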
@@ -118,5 +136,31 @@ def _quantize_group_cuda(
 
     def _quantize_group_native(
             self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        return quantize_fp8_per_group(x, self.group_size,
-                                      self.column_major_scales)
+        orig_shape = x.shape
+        hidden_dim = x.shape[-1]
+        num_groups = (hidden_dim + self.group_size - 1) // self.group_size
+        padded_dim = num_groups * self.group_size
+
+        if padded_dim != hidden_dim:
+            padding = padded_dim - hidden_dim
+            x = F.pad(x, (0, padding), mode='constant', value=0.0)
+
+        x_grouped = x.view(-1, num_groups, self.group_size)
+        absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+        x_scaled = x_grouped / scales
+        x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
+
+        x_quant = x_quant.view(-1, padded_dim)
+        if padded_dim != hidden_dim:
+            x_quant = x_quant[..., :hidden_dim]
+        x_quant = x_quant.view(orig_shape)
+
+        scales = scales.squeeze(-1)
+        scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
+
+        if self.column_major_scales:
+            scales = scales.transpose(-2, -1).contiguous()
+
+        return x_quant, scales
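
Because the new native group path pads the hidden dimension up to a multiple of group_size and can emit column-major scales, the output shapes are worth spelling out. The standalone sketch below replays the same logic as a free function to check them (group_size and column_major_scales are normally taken from self; the constants are assumed e4m3fn values):

import torch
import torch.nn.functional as F

_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


def group_quant_native(x, group_size, column_major_scales=False):
    orig_shape = x.shape
    hidden_dim = x.shape[-1]
    num_groups = (hidden_dim + group_size - 1) // group_size
    padded_dim = num_groups * group_size
    if padded_dim != hidden_dim:
        # Zero-pad the last dim so every group is full; zeros never raise
        # a group's absmax, so padding does not perturb the scales.
        x = F.pad(x, (0, padded_dim - hidden_dim), mode='constant', value=0.0)
    x_grouped = x.view(-1, num_groups, group_size)
    absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
    scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
    x_quant = (x_grouped / scales).clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
    x_quant = x_quant.reshape(-1, padded_dim)[..., :hidden_dim].reshape(orig_shape)
    scales = scales.squeeze(-1).reshape(orig_shape[:-1] + (num_groups,))
    if column_major_scales:
        scales = scales.transpose(-2, -1).contiguous()
    return x_quant, scales


x = torch.randn(2, 5, 200)                  # 200 is not a multiple of 128
q, s = group_quant_native(x, group_size=128, column_major_scales=True)
assert q.shape == x.shape
assert s.shape == (2, 2, 5)                 # num_groups=2, groups dim transposed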

vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py

Lines changed: 0 additions & 110 deletions
This file was deleted.
