[Kernel] Support deep_gemm for linear methods #19085

Merged: 5 commits (Jun 11, 2025)
84 changes: 84 additions & 0 deletions vllm/model_executor/layers/quantization/deepgemm.py
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import logging

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op

# deep_gemm is an optional dependency; only import it when it is available.
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
    import deep_gemm

logger = logging.getLogger(__name__)


def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> tuple[int, int, int, torch.Tensor]:
    # Validate the block-quantized operands and their scale shapes, then
    # allocate the output tensor C with shape A.shape[:-1] + (N,).
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]

    M = A.numel() // A.shape[-1]

    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    # DeepGEMM only supports bfloat16 output tensors.
    assert C.dtype == torch.bfloat16
    deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    return C


def w8a8_block_fp8_matmul_deepgemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    # Fake (meta) implementation: only produces a correctly shaped output.
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    return C


# Register the op under torch.ops.vllm so it is reachable through the
# dispatcher; the fake impl lets torch.compile trace shapes without DeepGEMM.
direct_register_custom_op(
    op_name="w8a8_block_fp8_matmul_deepgemm",
    op_func=w8a8_block_fp8_matmul_deepgemm,
    mutates_args=[],
    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
    dispatch_key=current_platform.dispatch_key,
)
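
Usage sketch (illustrative, not part of this diff): on an SM90 GPU with deep_gemm installed, the registered op is reachable through torch.ops.vllm once this module has been imported. The shapes below are made up, the import path for per_token_group_quant_fp8 is an assumption, and a real weight with its block scales would come from a quantized checkpoint.

import torch

import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401 (registers the op)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

M, N, K = 16, 4096, 4096     # tokens, output features, input features
block_size = [128, 128]      # (block_n, block_k)

# Quantize activations per token group, with the scale layout DeepGEMM expects.
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
A, As = per_token_group_quant_fp8(x, block_size[1], column_major_scales=True)

# Dummy block-quantized weight: one float32 scale per 128x128 block.
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
Bs = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)

C = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
    A, B, As, Bs, block_size, output_dtype=torch.bfloat16)
assert C.shape == (M, N) and C.dtype == torch.bfloat16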
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -402,6 +402,7 @@ def apply(self,

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
40 changes: 39 additions & 1 deletion vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -3,12 +3,14 @@

# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -20,6 +22,7 @@
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
    return w8a8_block_fp8_matmul


def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
    """
    Check whether the DeepGEMM kernel should be used for this linear layer.
    It requires a CUDA SM90 (Hopper) device with the deep_gemm package
    installed and VLLM_USE_DEEP_GEMM enabled, a bfloat16 output dtype, and
    weight dimensions divisible by the 128x128 block size.
    """

    return (current_platform.is_cuda()
            and current_platform.is_device_capability(90) and has_deep_gemm
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
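
For illustration (not part of the diff, and assuming an SM90 GPU with deep_gemm installed and VLLM_USE_DEEP_GEMM=1 set): the gate admits bfloat16 outputs whose weight dimensions are multiples of 128, and lets everything else fall back to the existing Triton/CUTLASS paths. The shapes below are hypothetical.

# Only the weight shape and the requested output dtype vary here.
w_aligned = torch.empty(4096, 2048, dtype=torch.float8_e4m3fn, device="cuda")
w_ragged = torch.empty(4096, 2050, dtype=torch.float8_e4m3fn, device="cuda")

assert should_use_deepgemm(torch.bfloat16, w_aligned)      # DeepGEMM path
assert not should_use_deepgemm(torch.float16, w_aligned)   # needs bf16 output
assert not should_use_deepgemm(torch.bfloat16, w_ragged)   # 2050 % 128 != 0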


# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    output_dtype = input.dtype

    if should_use_deepgemm(output_dtype, weight):

        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[0]]

        # Quantize activations per token group in the column-major scale
        # layout expected by DeepGEMM, then run the block-scaled GEMM.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d,
            block_size[1],
            column_major_scales=True,
        )

        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
            q_input,
            weight,
            x_scale,
            weight_scale,
            block_size,
            output_dtype=output_dtype)
        if bias is not None:
            output += bias
        return output.to(dtype=output_dtype).view(*output_shape)

    if current_platform.is_cuda():
        if current_platform.has_device_capability(100):
@@ -134,7 +173,6 @@ def ceil_div(x: int, y: int) -> int:

    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
        use_cutlass, use_aiter_and_is_supported)

    if use_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=use_cutlass)