Merged

26 commits
687a386
fixing merge issues with main branch
LopezCastroRoberto Sep 11, 2025
6e788c5
fix: apply pre-commit formatting
LopezCastroRoberto Sep 11, 2025
72a7809
add spdx license
LopezCastroRoberto Sep 11, 2025
e8fb103
fix: apply pre-commit formatting
LopezCastroRoberto Sep 11, 2025
d5ca6c5
style: apply auto-fixes from pre-commit
LopezCastroRoberto Sep 11, 2025
39bbc0e
Merge branch 'main' into transforms
LopezCastroRoberto Sep 22, 2025
d774b11
style: apply auto-fixes from pre-commit
LopezCastroRoberto Sep 22, 2025
9c6f15c
custom ops fake fix
BlackSamorez Sep 30, 2025
f4a6f15
use pytest and cuda version check
LopezCastroRoberto Oct 1, 2025
7d378e0
adding hint to cmake for CI
LopezCastroRoberto Oct 2, 2025
c5e0595
Merge branch 'vllm-project:main' into transforms
LopezCastroRoberto Oct 2, 2025
84d6dda
re-define fake methods
LopezCastroRoberto Oct 2, 2025
6bdaff3
style: apply auto-fixes from pre-commit
LopezCastroRoberto Oct 2, 2025
a167efd
[CI/Build] limit to sm100
LopezCastroRoberto Oct 3, 2025
d32c180
[CI/Build] flip sm order
LopezCastroRoberto Oct 3, 2025
05543ed
Basic FP-Quant integration
BlackSamorez Sep 30, 2025
fe63388
Merge pull request #1 from LopezCastroRoberto/integration
LopezCastroRoberto Oct 6, 2025
f55e5b0
fix FP-Quant integration
LopezCastroRoberto Oct 6, 2025
dedb28a
Merge branch 'main' into transforms
LopezCastroRoberto Oct 7, 2025
e754b83
style: apply auto-fixes from pre-commit
LopezCastroRoberto Oct 7, 2025
8cca014
change dir qutlass utils
LopezCastroRoberto Oct 8, 2025
b3e354e
remove noqa e501
LopezCastroRoberto Oct 8, 2025
50eec02
Merge branch 'vllm-project:main' into transforms
LopezCastroRoberto Oct 8, 2025
df98064
isolate QuTLASS compile flags
LopezCastroRoberto Oct 9, 2025
057cbde
fix docstring
LopezCastroRoberto Oct 9, 2025
df2ee30
Merge branch 'main' into transforms
mgoin Oct 9, 2025
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -835,6 +835,8 @@ steps:
  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
  - pytest -v -s tests/kernels/moe/test_flashinfer.py
  - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
  - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
  - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py

- label: Blackwell GPT-OSS Eval
  timeout_in_minutes: 60
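The two new pipeline entries run the QuTLASS NVFP4/MXFP4 kernel tests on the Blackwell stage. Per the commit history ("use pytest and cuda version check", "[CI/Build] limit to sm100"), these tests are expected to skip on hardware without the required compute capability. A minimal sketch of such a pytest guard, for illustration only (the marker name and threshold are assumptions, not the actual test code):

import pytest
import torch

# Hypothetical skip marker: the QuTLASS kernels in this PR target SM100
# (Blackwell), so tests guarded this way skip on older GPUs or CPU-only hosts.
requires_sm100 = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() < (10, 0),
    reason="QuTLASS MXFP4/NVFP4 kernels require compute capability >= 10.0",
)


@requires_sm100
def test_mxfp4_gemm_placeholder():
    # Placeholder body; the real tests live in tests/kernels/quantization/.
    assert True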
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -1007,6 +1007,7 @@ endif()
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
  include(cmake/external_projects/flashmla.cmake)
  include(cmake/external_projects/qutlass.cmake)

  # vllm-flash-attn should be last as it overwrites some CMake functions
  include(cmake/external_projects/vllm_flash_attn.cmake)
191 changes: 191 additions & 0 deletions benchmarks/kernels/bench_mxfp4_qutlass.py
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "mxfp4": dict(no_a_quant=False, enabled=True),
    "mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_mxfp4(
    b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
        b, forward_hadamard_matrix, method="abs_max"
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
    return weight_hf_e2m1, weight_hf_scale_block


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    alpha = torch.tensor([1.0], device="cuda")

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_mxfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_mxfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
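For context, here is a minimal standalone sketch of the quantize-then-GEMM path this benchmark times, using only the ops it already imports (fusedQuantizeMx, to_blocked, matmul_mxf4_bf16_tn). The shapes and the group size of 32 are illustrative, and it assumes a GPU and build where the QuTLASS extension is available:

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked

M, N, K, group = 16, 4096, 4096, 32  # illustrative shapes; group = Hadamard size
device, dtype = "cuda", torch.bfloat16

a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
had = deterministic_hadamard_matrix(group, dtype=dtype, device=device) * group**-0.5

# Fused Hadamard transform + MXFP4 quantization: E2M1 values plus E8M0 block scales.
a_e2m1, a_e8m0 = fusedQuantizeMx(a, had, method="abs_max")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, had, method="abs_max")

# Rearrange the scales into the blocked layout the GEMM kernel expects.
a_scales = to_blocked(a_e8m0, backend="triton")
b_scales = to_blocked(b_e8m0, backend="triton")

alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scales, b_scales, alpha)  # BF16, (M, N)

The benchmark script itself is driven through its argparse interface (--models, --tp-sizes), sweeping Hadamard sizes 32, 64, and 128 over the selected weight shapes.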