@@ -31,19 +31,37 @@ def is_deep_gemm_supported() -> bool:
31
31
32
32
33
33
@functools.cache
def is_blackwell_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scale on a
    Blackwell-class GPU.

    The checks below run in order; the first one that fails logs the reason
    at debug level and returns ``False``:

    1. ``envs.VLLM_USE_DEEP_GEMM`` must be truthy.
    2. ``has_deep_gemm()`` must report that the backend is available.
    3. ``envs.VLLM_USE_DEEP_GEMM_E8M0`` must be truthy.
    4. After ``_lazy_init()``, ``_fp8_gemm_nt_impl`` must not be ``None``.
    5. The platform must be CUDA and satisfy
       ``current_platform.has_device_capability(100)``.

    The function is wrapped in ``functools.cache`` and takes no arguments,
    so the outcome of the first call is reused by every later call.
    """
    if not envs.VLLM_USE_DEEP_GEMM:
        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.")
        return False

    if not has_deep_gemm():
        logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.")
        return False

    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
        return False

    # The implementation symbols are resolved lazily; do it only after the
    # cheap configuration checks above have passed.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.debug_once(
            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    # NOTE(review): has_device_capability(100) is presumably a
    # "capability >= 10.0" check (the previous code used
    # is_device_capability(100)) -- confirm against the platform interface.
    enabled = (current_platform.is_cuda()
               and current_platform.has_device_capability(100))
    if enabled:
        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
    else:
        logger.debug_once(
            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
    return enabled
47
65
48
66
49
67
def _missing (* _ : Any , ** __ : Any ) -> NoReturn :
@@ -109,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs):
109
127
_lazy_init ()
110
128
if _fp8_gemm_nt_impl is None :
111
129
return _missing (* args , ** kwargs )
112
- return _fp8_gemm_nt_impl (* args , ** kwargs )
130
+ return _fp8_gemm_nt_impl (
131
+ * args ,
132
+ disable_ue8m0_cast = not is_blackwell_deep_gemm_e8m0_used (),
133
+ ** kwargs )
113
134
114
135
115
136
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to the lazily resolved ``_grouped_impl``.

    ``_lazy_init()`` runs first. If ``_grouped_impl`` is still ``None``
    afterwards, the call is handed to ``_missing`` with the same arguments.
    Otherwise the implementation is invoked with the caller's positional and
    keyword arguments plus ``disable_ue8m0_cast``, which is set to the
    negation of ``is_blackwell_deep_gemm_e8m0_used()``.
    """
    _lazy_init()
    if _grouped_impl is not None:
        # The UE8M0 cast is disabled whenever the E8M0 path is not in use.
        skip_ue8m0_cast = not is_blackwell_deep_gemm_e8m0_used()
        return _grouped_impl(*args,
                             disable_ue8m0_cast=skip_ue8m0_cast,
                             **kwargs)
    return _missing(*args, **kwargs)
120
144
121
145
122
146
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to the lazily resolved ``_grouped_masked_impl``.

    ``_lazy_init()`` runs first. If ``_grouped_masked_impl`` is still
    ``None`` afterwards, the call is handed to ``_missing`` with the same
    arguments. Otherwise the implementation is invoked with the caller's
    positional and keyword arguments plus ``disable_ue8m0_cast``, which is
    set to the negation of ``is_blackwell_deep_gemm_e8m0_used()``.
    """
    _lazy_init()
    if _grouped_masked_impl is not None:
        # The UE8M0 cast is disabled whenever the E8M0 path is not in use.
        skip_ue8m0_cast = not is_blackwell_deep_gemm_e8m0_used()
        return _grouped_masked_impl(*args,
                                    disable_ue8m0_cast=skip_ue8m0_cast,
                                    **kwargs)
    return _missing(*args, **kwargs)
127
154
128
155
129
156
def _ceil_to_ue8m0 (x : torch .Tensor ):
@@ -181,6 +208,6 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
181
208
"m_grouped_fp8_gemm_nt_contiguous" ,
182
209
"fp8_m_grouped_gemm_nt_masked" ,
183
210
"per_block_cast_to_fp8" ,
184
- "is_blackwell_deep_gemm_used " ,
211
+ "is_blackwell_deep_gemm_e8m0_used " ,
185
212
"is_deep_gemm_supported" ,
186
213
]
0 commit comments