     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_MI300,
+    is_sm_89,
+    is_sm_90,
 )

 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -827,10 +830,11 @@ def _normalize_granularity(
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    processed_granularity = None
     if granularity is None:
-        return (PerTensor(), PerTensor())
+        processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
-        return (granularity, granularity)
+        processed_granularity = (granularity, granularity)
     elif isinstance(granularity, tuple) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
@@ -843,11 +847,25 @@ def _normalize_granularity(
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        return granularity
+        processed_granularity = granularity
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
+    # Validate granularity against supported hardware
+    for _granularity in processed_granularity:
+        if isinstance(_granularity, PerTensor):
+            assert (
+                is_sm_89() or is_MI300()
+            ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+        elif isinstance(_granularity, PerRow):
+            assert (
+                is_sm_90() or is_MI300()
+            ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")
+
+    return processed_granularity


 def _input_activation_quant_func_fp8(
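The refactor above funnels every accepted granularity through `processed_granularity`, so the new hardware check runs on all code paths instead of each branch returning early. A minimal sketch of the expected behavior, assuming the torchao module paths below (`_normalize_granularity` is a private helper, so this is illustrative only):

```python
# Illustrative sketch, not part of the patch. Module paths are assumptions
# based on torchao's layout; _normalize_granularity is a private helper.
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import _normalize_granularity

# None defaults to per-tensor scales for both activation and weight
# (passes the new check on SM 8.9+ GPUs or MI300).
assert _normalize_granularity(None) == (PerTensor(), PerTensor())

# A single granularity is broadcast to (activation, weight); PerRow now
# asserts SM 9.0+ or MI300 before any fp8 kernels are reached.
act_gran, wt_gran = _normalize_granularity(PerRow())
assert isinstance(act_gran, PerRow) and isinstance(wt_gran, PerRow)
```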
@@ -939,6 +957,9 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

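With the guard in place, unsupported hardware fails fast at configuration time rather than deep inside an fp8 matmul. A hedged usage sketch (the `quantize_` entry point and `granularity` argument are taken from torchao's public API):

```python
# Sketch only: assumes a CUDA build of torchao with fp8 support.
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
# On pre-SM-8.9 GPUs this now raises AssertionError here, at configuration
# time, instead of erroring later inside the quantized matmul.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```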
@@ -993,6 +1014,9 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 static activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

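The static variant gets the same guard; unlike the dynamic path it also needs a precomputed activation scale from calibration. A minimal sketch, assuming the `scale` parameter name from the function signature (the scale value here is a made-up placeholder):

```python
# Sketch only: the scale value is a hypothetical calibration result.
import torch
from torchao.quantization import quantize_, float8_static_activation_float8_weight

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
act_scale = torch.tensor(1.0, device="cuda")  # placeholder, not a real calibrated value
# The new assert rejects this up front on hardware without fp8 support.
quantize_(model, float8_static_activation_float8_weight(scale=act_scale))
```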