1
+ from typing import Callable
1
2
import torch
2
3
import torchao
3
4
from torchao .quantization .quant_primitives import (
@@ -500,7 +501,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActiv
500
501
"""
501
502
AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
502
503
"""
503
- activation_granularity : str = PerRow ()
504
+ activation_granularity = PerRow ()
504
505
@classmethod
505
506
def from_float (cls , weight ):
506
507
@@ -537,6 +538,42 @@ def get_per_token_block_size(x):
537
538
weight = super (AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight , cls ).from_float (weight , input_quant_func )
538
539
return weight
539
540
541
class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
    """
    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling.

    Both the weight and the (dynamically quantized) activation use a single
    float8_e4m3fn scale for the whole tensor, in contrast to the per-row
    sibling class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.
    """

    # Per-tensor granularity: one scale for the entire activation tensor.
    activation_granularity = PerTensor()

    @classmethod
    def from_float(cls, weight):
        """Quantize a 2D float weight to float8 with per-tensor scaling.

        Args:
            weight: a 2D floating-point weight tensor.

        Returns:
            An instance of this class wrapping the float8-quantized weight
            together with the dynamic float8 activation quantization function.
        """
        # Imported here to avoid a circular dependency at module import time.
        from torchao.dtypes import to_affine_quantized_floatx
        from torchao.quantization.quant_api import _input_activation_quant_func_fp8

        # Weight settings: one scale block covering the whole 2D tensor.
        def get_weight_block_size(x):
            assert x.ndim == 2, "Only works for 2D tensors"
            return x.shape

        target_dtype = torch.float8_e4m3fn
        input_target_dtype = torch.float8_e4m3fn
        _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))

        # def instead of lambda-assigned-to-name (PEP 8); binds cls's
        # per-tensor granularity into the activation quantizer.
        def input_quant_func(x):
            return _input_activation_quant_func_fp8(
                x=x,
                activation_granularity=cls.activation_granularity,
                activation_dtype=input_target_dtype,
            )

        block_size = get_weight_block_size(weight)
        weight = to_affine_quantized_floatx(
            input_float=weight,
            block_size=block_size,
            target_dtype=target_dtype,
            _layout=_layout,
            scale_dtype=torch.float32,
        )
        # NOTE(review): the per-row variant imports _is_rowwise_scaled here to
        # assert rowwise scaling; that import was copied over but never used
        # for per-tensor scaling, so it has been removed.
        weight = super().from_float(weight, input_quant_func)
        return weight
540
577
541
578
# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
542
579
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -557,6 +594,7 @@ def get_per_token_block_size(x):
557
594
# Float8 candidate weight subclasses tried by autoquant when the "other"
# class list is selected; order is preserved as the candidate search order.
OTHER_AUTOQUANT_CLASS_LIST = [
    AQFloat8WeightOnlyQuantizedLinearWeight,
    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
]
561
599
562
600
0 commit comments