@@ -25,6 +25,7 @@
 from torchao.dtypes.uintx.Uintx import UintxLayoutType
 from torchao.dtypes import (
     to_affine_quantized_intx,
+    to_affine_quantized_floatx,
     TensorCoreTiledLayoutType,
     PlainLayoutType,
     AffineQuantizedTensor,
@@ -670,6 +671,35 @@ def _validate_granularity(
     else:
         raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")

+def _get_block_size(x: torch.Tensor, granularity: _fp8_granularities):
+    if isinstance(granularity, PerTensor):
+        return x.shape
+    elif isinstance(granularity, PerRow):
+        return (1,) * (x.dim() - 1) + (x.shape[-1],)
+    else:
+        raise ValueError(f"Unsupported granularity: {granularity}")
+
+
+def _input_quant_func_dyanmic_fp8(
+    x: torch.Tensor,
+    activation_granularity: _fp8_granularities,
+    activation_dtype: torch.dtype,
+):
+    if isinstance(activation_granularity, PerRow):
+        assert (
+            x.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input activation"
+
+    block_size = _get_block_size(x, activation_granularity)
+    activation = to_affine_quantized_floatx(
+        input_float=x,
+        block_size=block_size,
+        target_dtype=activation_dtype,
+        scale_dtype=torch.float32,
+        layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+    )
+    return activation
+

 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
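
As an aside (not part of the diff): a minimal, self-contained sketch of the block sizes the new _get_block_size helper produces for the two supported granularities. The tensor shape and the standalone helper name below are illustrative only.

import torch

def block_size_sketch(x: torch.Tensor, per_row: bool):
    # Mirrors the logic of _get_block_size above:
    # per-tensor -> one block spanning the whole tensor (a single scale),
    # per-row    -> (1, ..., 1, last_dim), i.e. one scale per row.
    if per_row:
        return (1,) * (x.dim() - 1) + (x.shape[-1],)
    return tuple(x.shape)

x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
print(block_size_sketch(x, per_row=False))  # (2, 16, 64)
print(block_size_sketch(x, per_row=True))   # (1, 1, 64)
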
@@ -693,28 +723,18 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
-    from torchao.dtypes import to_affine_quantized_floatx
-
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

     activation_granularity, weight_granularity = _validate_granularity(granularity)

-    def get_block_size(x: torch.Tensor, granularity: _fp8_granularities):
-        if isinstance(granularity, PerTensor):
-            return x.shape
-        elif isinstance(granularity, PerRow):
-            return (1,) * (x.dim() - 1) + (x.shape[-1],)
-        else:
-            raise ValueError(f"Unsupported granularity: {granularity}")
-
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         if isinstance(weight_granularity, PerRow):
             assert (
                 weight.dtype == torch.bfloat16
             ), "PerRow quantization only works for bfloat16 precision input weight"

-        block_size = get_block_size(weight, weight_granularity)
+        block_size = _get_block_size(weight, weight_granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=block_size,
@@ -723,23 +743,11 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
             layout_type=Float8LayoutType(mm_config=mm_config),
         )

-        def input_quant_func(x: torch.Tensor):
-            if isinstance(activation_granularity, PerRow):
-                assert (
-                    x.dtype == torch.bfloat16
-                ), "PerRow quantization only works for bfloat16 precision input activation"
-
-            block_size = get_block_size(x, activation_granularity)
-            activation = to_affine_quantized_floatx(
-                input_float=x,
-                block_size=block_size,
-                target_dtype=activation_dtype,
-                scale_dtype=torch.float32,
-                layout_type=Float8LayoutType(
-                    mm_config=None
-                ),  # Config is stored on weight
-            )
-            return activation
+        input_quant_func = partial(
+            _input_quant_func_dyanmic_fp8,
+            activation_granularity=activation_granularity,
+            activation_dtype=activation_dtype,
+        )

         quantized_weight = to_linear_activation_quantized(
             quantized_weight, input_quant_func
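
For context only (not part of the diff): a hedged usage sketch of the refactored path, in which quantize_ applies float8_dynamic_activation_float8_weight to a model's linear layers and the activation side now runs through the module-level _input_quant_func_dyanmic_fp8 bound with functools.partial. The PerRow import path and the hardware note are assumptions about the surrounding torchao version.

import torch
import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
# PerRow/PerTensor may live under torchao.quantization or torchao.quantization.granularity,
# depending on the torchao version.
from torchao.quantization import PerRow

# float8 matmuls need suitable hardware (e.g. H100); PerRow requires bfloat16 weights.
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
# Weights are quantized eagerly; activations are quantized on each forward pass via
# the partial(_input_quant_func_dyanmic_fp8, ...) attached as input_quant_func.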