Commit 478d15b

Add hardware check to fp8 quant (#1314)
1 parent b65e513 commit 478d15b

File tree (3 files changed, +59 −3 lines):

test/dtypes/test_affine_quantized_float.py
torchao/quantization/quant_api.py
torchao/utils.py

test/dtypes/test_affine_quantized_float.py

Lines changed: 3 additions & 0 deletions
@@ -134,17 +134,20 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
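The new skips gate on an `is_cuda_8_9` flag imported by the test module; its definition is not part of this diff. A minimal sketch of how such a module-level gate is typically written (an assumption, not the exact helper from torchao's test suite):

import torch

# Evaluated once at import time so it can be used in @unittest.skipIf
# decorators; True only when a CUDA device with compute capability >= 8.9
# (Ada / Hopper generation) is visible.
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)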

torchao/quantization/quant_api.py

Lines changed: 27 additions & 3 deletions
@@ -51,6 +51,9 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_MI300,
+    is_sm_89,
+    is_sm_90,
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -827,10 +830,11 @@ def _normalize_granularity(
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    processed_granularity = None
     if granularity is None:
-        return (PerTensor(), PerTensor())
+        processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
-        return (granularity, granularity)
+        processed_granularity = (granularity, granularity)
     elif isinstance(granularity, tuple) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
@@ -843,11 +847,25 @@ def _normalize_granularity(
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        return granularity
+        processed_granularity = granularity
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
+    # Validate granularity with supported Hardware
+    for _granularity in processed_granularity:
+        if isinstance(_granularity, PerTensor):
+            assert (
+                is_sm_89() or is_MI300()
+            ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+        elif isinstance(_granularity, PerRow):
+            assert (
+                is_sm_90() or is_MI300()
+            ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")
+
+    return processed_granularity
 
 
 def _input_activation_quant_func_fp8(
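The effect of the rewritten `_normalize_granularity` is that every accepted spelling of `granularity` is first resolved to an `(activation, weight)` pair and only then checked against the hardware, so the asserts fire regardless of which form the caller used. A hedged usage sketch (import paths follow this commit's layout but may differ in other torchao versions):

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import float8_dynamic_activation_float8_weight

# None resolves to (PerTensor(), PerTensor()); a single instance is duplicated
# for activation and weight; a matching 2-tuple passes through after validation.
cfg_default = float8_dynamic_activation_float8_weight()                      # PerTensor
cfg_per_row = float8_dynamic_activation_float8_weight(granularity=PerRow())  # PerRow

# On an sm_89 device (e.g. RTX 4090) the first call succeeds but the second
# trips the "PerRow quantization only works for CUDA>=9.0 and MI300+" assert;
# on H100 (sm_90) or MI300 both succeed.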
@@ -939,6 +957,9 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
@@ -993,6 +1014,9 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
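Since both public entry points now assert on hardware support before doing any work, code that targets mixed fleets may want to probe first and fall back. A small sketch using the helpers this commit exports from torchao/utils.py (the fallback strategy itself is illustrative):

from torchao.utils import is_MI300, is_sm_89

def fp8_quant_supported() -> bool:
    # Mirrors the guard inside float8_dynamic_activation_float8_weight:
    # NVIDIA compute capability >= 8.9, or an MI300-class ROCm device.
    return is_sm_89() or is_MI300()

if fp8_quant_supported():
    ...  # safe to build a float8 quantization config
else:
    ...  # e.g. fall back to int8 weight-only quantization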

torchao/utils.py

Lines changed: 29 additions & 0 deletions
@@ -31,6 +31,9 @@
     "TORCH_VERSION_AFTER_2_3",
     "TORCH_VERSION_AFTER_2_4",
     "TORCH_VERSION_AFTER_2_5",
+    "is_MI300",
+    "is_sm_89",
+    "is_sm_90",
 ]
 
 
@@ -586,6 +589,32 @@ def _torch_version_at_least(min_version):
     return is_fbcode() or version("torch") >= min_version
 
 
+def is_MI300():
+    if torch.cuda.is_available() and torch.version.hip:
+        mxArchName = ["gfx940", "gfx941", "gfx942"]
+        archName = torch.cuda.get_device_properties().gcnArchName
+        for arch in mxArchName:
+            if arch in archName:
+                return True
+    return False
+
+
+def is_sm_89():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (8, 9)
+    )
+
+
+def is_sm_90():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (9, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
