[Benchmark] Refactor benchmark script for fp8 & int8 #19627
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
@@ -11,35 +10,89 @@
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "fp8-tensor-w-token-a": dict(
        w="tensor", a="token", no_a_quant=False, enabled=False
    ),
    "fp8-tensor-w-tensor-a": dict(
        w="tensor", a="tensor", no_a_quant=False, enabled=True
    ),
    "fp8-channel-w-token-a": dict(
        w="channel", a="token", no_a_quant=False, enabled=True
    ),
    "fp8-channel-w-tensor-a": dict(
        w="channel", a="tensor", no_a_quant=False, enabled=False
    ),
    "fp8-tensor-w-token-a-noquant": dict(
        w="tensor", a="token", no_a_quant=True, enabled=False
    ),
    "fp8-tensor-w-tensor-a-noquant": dict(
        w="tensor", a="tensor", no_a_quant=True, enabled=True
    ),
    "fp8-channel-w-token-a-noquant": dict(
        w="channel", a="token", no_a_quant=True, enabled=True
    ),
    "fp8-channel-w-tensor-a-noquant": dict(
        w="channel", a="tensor", no_a_quant=True, enabled=False
    ),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
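The table above replaces the old hard-coded (and partly commented-out) line_vals/line_names lists further down in this diff: a provider is benchmarked exactly when its enabled flag is true. A minimal, dependency-free sketch of that selection logic (provider names copied from the table; the benchmark machinery itself is omitted):

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "fp8-tensor-w-tensor-a": dict(w="tensor", a="tensor", no_a_quant=False, enabled=True),
    "fp8-channel-w-token-a": dict(w="channel", a="token", no_a_quant=False, enabled=True),
    "fp8-channel-w-tensor-a": dict(w="channel", a="tensor", no_a_quant=False, enabled=False),
}

# Only enabled entries become plot lines; flipping a single flag adds or
# removes a curve without editing line_vals/line_names by hand.
enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg["enabled"]]
print(enabled)  # ['torch-bf16', 'fp8-tensor-w-tensor-a', 'fp8-channel-w-token-a']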
def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    else:
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
    return b_fp8.t(), scale_b_fp8
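The two branches differ only in the shape of the weight scale: per-tensor quantization uses one scalar scale, while the per-channel path (dynamic per-token quantization applied to the [N, K] weight) yields one scale per output channel, matching the scale_b_fp8.numel() == 1 / == N assertions in the code this PR removes below. A small hedged sanity check of those shapes; it assumes a CUDA GPU with vLLM installed, since scaled_fp8_quant is a CUDA op:

import torch
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

N, K = 4096, 4096
b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)

# Per-tensor: one fixed scale for the whole weight matrix.
b_fp8, scale_b = vllm_scaled_fp8_quant(b, torch.ones(1, device="cuda", dtype=torch.float32))
assert scale_b.numel() == 1

# Per-channel: dynamic per-row quantization gives one scale per output channel.
b_fp8, scale_b = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
assert scale_b.numel() == N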
def build_fp8_runner(cfg, a, b, dtype, device):
    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)

    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32)
        if cfg["a"] == "tensor"
        else None
    )

    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
        else:
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        return run

    if cfg["a"] == "tensor":

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    else:

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    return run
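For context, this is how the refactored benchmark below consumes a runner: it looks the provider up in PROVIDER_CFGS, builds the closure once, and then times only run(). A hedged usage sketch that assumes the definitions above are in scope and a CUDA device with vLLM is available:

import torch

device, dtype = "cuda", torch.bfloat16
M, N, K = 16, 4096, 8192
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)

# Weights (and, for the -noquant variants, activations) are quantized once
# up front; run() contains only the work that should be timed.
cfg = PROVIDER_CFGS["fp8-channel-w-token-a"]
run = build_fp8_runner(cfg, a, b, dtype, device)
out = run()
assert out.shape == (M, N) and out.dtype == dtype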
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        line_names=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
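For readers unfamiliar with Triton's benchmarking helper: perf_report turns the decorated function into an object whose run() method sweeps x_vals, calls the function once per (x value, line_vals entry) pair, and tabulates or plots the returned metric. A minimal CPU-only sketch of the same pattern, with a dummy metric and made-up names; it assumes triton is installed (plus pandas and matplotlib, which the reporting helper uses), while the real script above times CUDA GEMMs instead:

import triton
import triton.testing  # make sure the testing submodule is loaded

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64],
        x_log=False,
        line_arg="provider",
        line_vals=["a", "b"],
        line_names=["provider-a", "provider-b"],
        ylabel="dummy metric",
        plot_name="perf_report sketch",
        args={"N": 4096, "K": 4096},
    )
)
def sketch(batch_size, provider, N, K):
    # One curve per line_vals entry; the x axis sweeps batch_size.
    return batch_size * (2 if provider == "b" else 1)

sketch.run(print_data=True, show_plots=False)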
@@ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K):
    device = "cuda"
    dtype = torch.bfloat16

    # Create input tensors
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if "torch-bf16" in provider:
    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )

    elif "fp8" in provider:
        # Weights are always quantized ahead of time
        if "noquant" in provider:
            # For no quantization, we just measure the GEMM
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        else:
            # In these cases, we quantize the activations during the GEMM call
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        b_fp8 = b_fp8.t()

    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    # Calculate TFLOP/s, two flops per multiply-add
    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
    return tflops(ms), tflops(max_ms), tflops(min_ms)
    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
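The renamed to_tflops helper is the same conversion as before: 2 * M * N * K counts one multiply and one add per accumulated product, 1e-12 converts FLOPs to TFLOPs, and 1e-3 converts the measured milliseconds to seconds. A quick worked example of the arithmetic:

# An M = N = K = 4096 GEMM is 2 * 4096**3, roughly 1.37e11 FLOPs,
# so a 0.5 ms run corresponds to roughly 275 TFLOP/s.
M = N = K = 4096
t_ms = 0.5
print((2 * M * N * K) * 1e-12 / (t_ms * 1e-3))  # ~274.9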
def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
            out.append(KN)
    return out
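prepare_shapes divides the tensor-parallel split dimension of every (K, N) weight shape by the requested TP size and tags each shape with its model name. A self-contained sketch with a made-up WEIGHT_SHAPES entry (the real table is imported at the top of the script; the shapes below are illustrative only):

import copy
import itertools

# Hypothetical entry: each item is ([K, N], tp_split_dim).
WEIGHT_SHAPES = {"example/model": [([4096, 6144], 1), ([4096, 4096], 0)]}
models, tp_sizes = ["example/model"], [2]

out = []
for model, tp_size in itertools.product(models, tp_sizes):
    for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
        KN[tp_dim] //= tp_size  # shard the TP-split dimension
        KN.append(model)        # -> [K, N, model_name]
        out.append(KN)

print(out)  # [[4096, 3072, 'example/model'], [2048, 4096, 'example/model']]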
if __name__ == "__main__":
@@ -197,21 +140,13 @@ def prepare_shapes(args):
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=[*WEIGHT_SHAPES.keys()],
        help="List of models to benchmark",
Review comment on this line: why remove the help here?
Author reply: I think the name "models" already shows the meaning and we can save some space by dropping "help", but I can certainly add it back if you wish.
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
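With the defaults above, a typical invocation would be something like "python bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1 2"; the file name is assumed here, since it is not visible in this excerpt. The script then prints one BF16 vs FP8 TFLOP/s table per (N, K, model) shape returned by prepare_shapes, with one line per enabled provider.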