
Commit b50d163

Add PyTorch implementation for QuantFP8 group quantization
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
1 parent 74bd084 commit b50d163

File tree: 6 files changed (+501 / -60 lines)
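
For orientation (the change to QuantFP8 itself lives in one of the files not excerpted below), here is a minimal sketch of what dynamic per-token-group FP8 quantization computes: every run of group_size contiguous elements along the last dimension shares one scale derived from the group's max absolute value. This is an illustrative sketch, not vLLM's implementation; the helper name, the float8_e4m3fn dtype, and the eps clamp are assumptions, and it assumes the hidden dimension is divisible by group_size.

import torch


def group_quant_fp8_sketch(x: torch.Tensor, group_size: int,
                           eps: float = 1e-10) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative sketch only (not vLLM's QuantFP8.forward_native).
    # Assumes x is contiguous and x.shape[-1] is divisible by group_size.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    orig_shape = x.shape
    x_grouped = x.reshape(*orig_shape[:-1], orig_shape[-1] // group_size, group_size)

    # One scale per group: max-abs / fp8_max, clamped away from zero.
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).float()
    scales = (amax / fp8_max).clamp(min=eps)

    x_q = (x_grouped.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(orig_shape), scales.squeeze(-1)


# e.g. a (16, 256) bf16 input with group_size=64 yields scales of shape (16, 4)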
New file (benchmark script) — 148 additions & 0 deletions

#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark for QuantFP8 Group Quantization implementation."""

import argparse

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


def _time_cuda(
    fn,
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def run_benchmark(
    shape: tuple[int, int],
    group_size: int,
    column_major: bool,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    """Benchmark QuantFP8 with group quantization using different backends."""
    num_tokens, hidden_dim = shape

    device = torch.device("cuda")
    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(
        static=False, group_shape=group_shape, column_major_scales=column_major
    )

    def cuda_impl():
        return quant_op.forward_cuda(x.clone())

    def native_impl():
        return quant_op.forward_native(x.clone())

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    native_ms = _time_cuda(native_impl, warmup_iters, bench_iters)

    speedup = cuda_ms / native_ms if native_ms else 0

    cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}"
    print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Benchmark QuantFP8 group quantization implementation"
    )
    parser.add_argument(
        "--warmup-iters", type=int, default=10, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--bench-iters", type=int, default=100, help="Number of benchmark iterations"
    )
    parser.add_argument(
        "--shapes",
        type=str,
        default="32,128;64,256;16,512;128,1024;256,2048",
        help="Shapes to benchmark as 'tokens,hidden;...' (default: multiple shapes)",
    )
    parser.add_argument(
        "--group-sizes",
        type=str,
        default="64,128",
        help="Group sizes to benchmark (comma-separated)",
    )
    parser.add_argument(
        "--no-column-major",
        action="store_true",
        help="Skip column-major scale benchmarks",
    )
    return parser.parse_args()


def main():
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()

    shapes = []
    for shape_str in args.shapes.split(";"):
        tokens, hidden = map(int, shape_str.split(","))
        shapes.append((tokens, hidden))

    group_sizes = list(map(int, args.group_sizes.split(",")))

    print("\n" + "=" * 80)
    print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)")
    print("=" * 80)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Warmup iterations: {args.warmup_iters}")
    print(f"Benchmark iterations: {args.bench_iters}")
    print("=" * 80)

    print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}")
    print("-" * 80)

    for shape in shapes:
        for gs in group_sizes:
            run_benchmark(
                shape,
                gs,
                column_major=False,
                warmup_iters=args.warmup_iters,
                bench_iters=args.bench_iters,
            )

            if not args.no_column_major:
                run_benchmark(
                    shape,
                    gs,
                    column_major=True,
                    warmup_iters=args.warmup_iters,
                    bench_iters=args.bench_iters,
                )

    print("=" * 80)


if __name__ == "__main__":
    main()
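
As a usage note, the benchmark accepts the flags defined in parse_args above; an invocation might look like the line below, where <path-to-benchmark-script> is a placeholder since the file path is not shown in this excerpt.

python <path-to-benchmark-script> --shapes "32,128;64,256" --group-sizes 64,128 --bench-iters 200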
New file (unit tests) — 206 additions & 0 deletions

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize("batch_size", [16, 32])
@pytest.mark.parametrize("hidden_dim",
                         [256, 512, 513])  # Include non-divisible
@pytest.mark.parametrize("group_size", [32, 64, 128])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_basic(batch_size: int, hidden_dim: int,
                              group_size: int, seed: int) -> None:
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    # Create QuantFP8 with group quantization
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    expected_num_groups = (hidden_dim + group_size - 1) // group_size

    # Test CUDA implementation (only supports divisible dimensions)
    if hidden_dim % group_size == 0:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)

    # Test PyTorch native implementation
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    # Test column_major_scales
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x.clone())
    assert scales_col.shape == (expected_num_groups, batch_size)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
    current_platform.seed_everything(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 512
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)


@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("hidden_dim", [1024])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_cuda_native_consistency(batch_size: int,
                                                hidden_dim: int,
                                                group_size: int,
                                                seed: int) -> None:
    """Compare CUDA and native implementations for consistency."""
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    # Run both implementations
    x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
    x_quant_native, scales_native = quant_op.forward_native(x.clone())

    # Check shapes match
    assert x_quant_cuda.shape == x_quant_native.shape
    assert scales_cuda.shape == scales_native.shape

    # Scales should match
    assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

    # Quantized values should mostly match, with rare rounding differences
    # FP8 rounding at boundaries can differ between CUDA and PyTorch
    diff_count = (x_quant_cuda != x_quant_native).sum().item()
    diff_ratio = diff_count / x_quant_cuda.numel()
    assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    current_platform.seed_everything(seed)

    batch_size = 16
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn(
        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs
    x_zero = torch.zeros((batch_size, 256),
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256),
                         1000.0,
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape
    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (scales_large > 1.0).all(), "Large values should have scales > 1"


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 16),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int,
                                        group_size: int, seed: int) -> None:
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    expected_num_groups = (hidden_dim + group_size - 1) // group_size

    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    if hidden_dim % group_size == 0:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)
        assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
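
The column-major assertions above all amount to swapping the last two dimensions of the row-major scales: (batch, groups) becomes (groups, batch), (b1, b2, groups) becomes (b1, groups, b2), and (b1, b2, b3, groups) becomes (b1, b2, groups, b3). Below is a minimal sketch of such a layout transform, assuming scales are first produced in (..., tokens, num_groups) form; whether vLLM materializes a contiguous transposed tensor or only adjusts strides is not shown in this excerpt.

import torch


def to_column_major_scales_sketch(scales: torch.Tensor) -> torch.Tensor:
    # (..., tokens, num_groups) -> (..., num_groups, tokens); matches the
    # 2D/3D/4D shapes asserted in the tests above.
    return scales.transpose(-2, -1).contiguous()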

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 1 deletion
@@ -32,9 +32,11 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
