
Commit 3700642

[Refactor] Remove Duplicate per_block_cast_to_fp8, Remove Dependencies of DeepGEMM (#21787)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 0bd409c commit 3700642

File tree

8 files changed (+55, -132 lines changed)


benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 6 additions & 39 deletions
@@ -4,49 +4,16 @@
 # ruff: noqa: E501
 import time
 
-# Import DeepGEMM functions
-import deep_gemm
 import torch
-from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
-# Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
 from vllm.triton_utils import triton
-
-
-# Copied from
-# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
-def per_token_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    """Convert tensor to FP8 format with per-token scaling."""
-    assert x.dim() == 2 and x.size(1) % 128 == 0
-    m, n = x.shape
-    x_view = x.view(m, -1, 128)
-    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
-        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
-
-
-# Copied from
-# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
-def per_block_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    """Convert tensor to FP8 format with per-block scaling."""
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
-                           dtype=x.dtype,
-                           device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
 
 
 def benchmark_shape(m: int,
@@ -69,14 +36,14 @@ def benchmark_shape(m: int,
 
     # Pre-quantize B for all implementations
     # (weights can be pre-quantized offline)
-    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
-    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
+    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
+    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
 
     # Block size configuration
     block_size = [128, 128]
 
     # Pre-quantize A for all implementations
-    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
+    A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
     C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
@@ -85,7 +52,7 @@ def benchmark_shape(m: int,
 
     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
+        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
                                        (B_deepgemm, B_scale_deepgemm),
                                        C_deepgemm)
         return C_deepgemm
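
For context, a minimal sketch of the call pattern the refactored benchmark now relies on. This is a sketch only: it assumes a CUDA device with FP8 support and DeepGEMM installed, and the shapes are illustrative, not taken from the benchmark.

# Sketch only: mirrors the updated benchmark's call pattern; shapes are illustrative.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8)
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8

m, n, k = 128, 4096, 7168
A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

# Weights: per-block (128x128) FP8 quantization via the shared vLLM helper.
B_fp8, B_scale = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)

# Activations: per-token-group FP8 quantization; scales laid out for DeepGEMM's TMA loads.
A_fp8, A_scale = per_token_group_quant_fp8(A, 128)
A_scale = get_col_major_tma_aligned_tensor(A_scale)

# bf16 output computed by the DeepGEMM wrapper.
C = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
fp8_gemm_nt((A_fp8, A_scale), (B_fp8, B_scale), C)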

tests/kernels/moe/modular_kernel_tools/utils.py

Lines changed: 3 additions & 28 deletions
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 
 import torch
 
 import vllm._custom_ops as ops
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def per_token_cast_to_fp8(
@@ -20,29 +20,6 @@ def per_token_cast_to_fp8(
     return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
-def per_block_cast_to_fp8(
-        x: torch.Tensor, block_size_k: int,
-        block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros(
-        (
-            int(math.ceil(m / block_size_k)) * block_size_k,
-            int(math.ceil(n / block_size_n)) * block_size_n,
-        ),
-        dtype=x.dtype,
-        device=x.device,
-    )
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, block_size_k,
-                           x_padded.size(1) // block_size_k, block_size_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def make_non_quant_weights(
     e: int,
     n: int,
@@ -99,11 +76,9 @@ def make_block_quant_fp8_weights(
 
     for i in range(e):
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
-                                               block_size_k=block_k,
-                                               block_size_n=block_n)
+                                               block_size=[block_k, block_n])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
-                                               block_size_k=block_k,
-                                               block_size_n=block_n)
+                                               block_size=[block_k, block_n])
 
     return w1, w2, w1_s, w2_s
 
tests/kernels/moe/test_cutlass_grouped_gemm.py

Lines changed: 2 additions & 19 deletions
@@ -12,10 +12,8 @@
 from tests.kernels.utils import baseline_scaled_mm
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def per_token_cast_to_fp8(
@@ -32,21 +30,6 @@ def per_token_cast_to_fp8(
     return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
-def per_block_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
-                           device=x.device,
-                           dtype=x.dtype)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-
-
 @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
     (4, 8192, 7168, 4096),
     (4, 8192, 2048, 7168),

tests/kernels/moe/test_deepgemm.py

Lines changed: 6 additions & 2 deletions
@@ -69,8 +69,12 @@ def make_block_quant_fp8_weights(
                        dtype=torch.float32)
 
     for i in range(e):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
+                                               block_size=block_size,
+                                               use_ue8m0=True)
+        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
+                                               block_size=block_size,
+                                               use_ue8m0=True)
 
     return w1, w2, w1_s, w2_s
 
tests/kernels/moe/utils.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,7 @@
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (per_block_cast_to_fp8,
-                                       per_block_cast_to_int8)
+from tests.kernels.quant_utils import per_block_cast_to_int8
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@@ -15,6 +14,7 @@
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
 from vllm.utils import round_up
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def triton_moe(

tests/kernels/quant_utils.py

Lines changed: 0 additions & 19 deletions
@@ -222,25 +222,6 @@ def native_per_token_group_quant_int8(x,
 DEFAULT_BLOCK_SHAPE = [128, 128]
 
 
-def per_block_cast_to_fp8(
-    x: torch.Tensor,
-    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    block_m, block_n = block_shape
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
-                           dtype=x.dtype,
-                           device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def per_block_cast_to_int8(
     x: torch.Tensor,
     block_shape: list[int] = DEFAULT_BLOCK_SHAPE,

tests/kernels/quantization/test_block_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
     A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
-    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
+    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
 
     As = As_fp8.to(torch.float32)
     Bs = Bs_fp8.to(torch.float32)

vllm/utils/deep_gemm.py

Lines changed: 35 additions & 22 deletions
@@ -14,7 +14,7 @@
 
 import vllm.envs as envs
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_gemm
+from vllm.utils import cdiv, has_deep_gemm
 
 
 @functools.cache
@@ -37,7 +37,7 @@ def is_blackwell_deep_gemm_used() -> bool:
         return False
 
     _lazy_init()
-    if _per_block_cast_impl is None:
+    if _fp8_gemm_nt_impl is None:
         return False
 
     return (current_platform.is_cuda()
@@ -63,18 +63,15 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
-_per_block_cast_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
-        _per_block_cast_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
 
     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
-            or _grouped_masked_impl is not None
-            or _per_block_cast_impl is not None):
+            or _grouped_masked_impl is not None):
         return
 
     if not has_deep_gemm():
@@ -90,14 +87,6 @@ def _lazy_init() -> None:
     _grouped_masked_impl = _resolve_symbol(
         _dg, "fp8_m_grouped_gemm_nt_masked",
         "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
-    # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
-    try:
-        _math_mod = importlib.import_module(
-            "deep_gemm.utils.math")  # type: ignore
-        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
-                                       None)
-    except ModuleNotFoundError:
-        _per_block_cast_impl = None
 
 
 def fp8_gemm_nt(*args, **kwargs):
@@ -121,13 +110,37 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     return _grouped_masked_impl(*args, **kwargs)
 
 
-def per_block_cast_to_fp8(x, *args, **kwargs):
-    _lazy_init()
-    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
-        return _per_block_cast_impl(x, use_ue8m0=True)
-    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
-    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
-    return _pbcf(x, *args, **kwargs)
+def _ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
+def _align(x: int, y: int) -> int:
+    return cdiv(x, y) * y
+
+
+DEFAULT_BLOCK_SIZE = [128, 128]
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
+# TODO(wentao): optimize this function, using triton or cuda kernel
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size: list[int] = DEFAULT_BLOCK_SIZE,
+        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    block_m, block_n = block_size
+    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
+                           dtype=x.dtype,
+                           device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2))
 
 
 def calc_diff(x: torch.Tensor, y: torch.Tensor):