    quantize_per_channel_group,
)

-from torchao.dtypes import PlainLayout
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
@@ -516,7 +515,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
    """
-    Configuration for dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
+    Configuration for dynamically quantizing activations with 8-bits and quantizing weights with a low-bit value.
    More specifically, activations are dynamically quantized to 8-bits in a channelwise manner with scales and zeros.
    Weights are quantized with scales and optionally zeros (controlled by has_weight_zeros) in a groupwise or channelwise
    manner using the number of bits specified by weight_dtype.
@@ -527,20 +526,17 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
        has_weight_zeros: Whether or not to include zeros in the weight quantization.
        weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
        act_mapping_type: The type of mapping to use for the activation quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
-        layout: The layout to use for the packed weight tensor. Must be PackedLinearInt8DynamicActivationIntxWeightLayout (default) or PlainLayout.
-            The layout does not affect the quantization numerically and both layouts will give the same results. PlainLayout is a generic layout
-            that works on all devices, but it is much slower than PackedLinearInt8DynamicActivationIntxWeightLayout on CPU.
-            PackedLinearInt8DynamicActivationIntxWeightLayout is a specialized layout for CPU performance.
-            When using PackedLinearInt8DynamicActivationIntxWeightLayout,
-            - The weight tensor must have device=CPU
-            - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32)
-            - act_mapping_type must be MappingType.ASYMMETRIC
+        layout: The layout to use for the packed weight tensor. The layout does not affect the quantization numerically, and different
+            layouts will give similar results. The following layouts are available:
+            - PackedLinearInt8DynamicActivationIntxWeightLayout: This layout is optimized for CPU performance.
+            - QDQLayout: This layout is designed for export to ExecuTorch.
+            - PlainLayout: This layout is a simple Python-based layout. It has low performance, but can be used
+              when PackedLinearInt8DynamicActivationIntxWeightLayout is unavailable.
    """

    weight_dtype: torch.dtype = torch.int4
    granularity: Union[PerRow, PerGroup] = PerRow()
    has_weight_zeros: bool = False
-    has_bias: bool = False
    weight_mapping_type: MappingType = MappingType.ASYMMETRIC
    act_mapping_type: MappingType = MappingType.ASYMMETRIC
    layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target="native")
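
To make the configuration concrete, here is a minimal usage sketch. The import paths and the `quantize_` entry point are assumptions about torchao's public API rather than something stated in this diff; adjust them to match the installed version.

# Minimal usage sketch of the config above. Import locations and the quantize_
# entry point are assumptions, not taken verbatim from this diff.
import torch
import torch.nn as nn

from torchao.quantization import quantize_  # assumed entry point
from torchao.quantization.granularity import PerGroup

# Location of the config is an assumption; it may live elsewhere in your build.
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig

# A toy float32 CPU model, as required by the default packed layout.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64)).eval()

# Dynamic 8-bit activations + 4-bit groupwise weights with the default
# PackedLinearInt8DynamicActivationIntxWeightLayout(target="native").
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    has_weight_zeros=False,
)
quantize_(model, config)

# Activations are quantized on the fly, so inference is unchanged for the caller.
with torch.no_grad():
    out = model(torch.randn(1, 256))
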
@@ -559,27 +555,10 @@ def _int8_dynamic_activation_intx_weigh_transform(
    weight_dtype = config.weight_dtype
    granularity = config.granularity
    has_weight_zeros = config.has_weight_zeros
-    has_bias = config.has_bias
    weight_mapping_type = config.weight_mapping_type
    act_mapping_type = config.act_mapping_type
    layout = config.layout

-    def is_torchao_op_skippable(layout):
-        return isinstance(layout, PlainLayout) or (
-            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-            and layout.target == Target.ATEN
-        )
-
-    if not is_torchao_op_skippable(layout):
-        try:
-            torch.ops.torchao._pack_8bit_act_4bit_weight
-        except AttributeError:
-            raise Exception(
-                "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
-                + " You can also set target to 'aten' if you are using ARM CPU."
-                + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
-            )
-
    dtype_to_bit_width = {
        torch.int1: 1,
        torch.int2: 2,
@@ -603,7 +582,18 @@ def is_torchao_op_skippable(layout):
    else:
        raise ValueError(f"granularity must be PerGroup or PerRow, got {granularity}")

+    tensor_impl_ctr_kwargs = None
    if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
+        # We need to create a new layout object for each module because when
+        # granularity is PerRow, the layout objects cannot share the group_size
+        layout = PackedLinearInt8DynamicActivationIntxWeightLayout(layout.target)
+        layout.set_params(
+            bit_width=bit_width,
+            group_size=group_size,
+            has_weight_zeros=has_weight_zeros,
+            has_bias=False,
+        )
+
        assert (
            weight.device == torch.device("cpu")
        ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
@@ -613,20 +603,24 @@ def is_torchao_op_skippable(layout):
        assert (
            act_mapping_type == MappingType.ASYMMETRIC
        ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-        assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
-        layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
-            bit_width=bit_width,
-            group_size=group_size,
-            has_weight_zeros=has_weight_zeros,
-            has_bias=has_bias,
-            target="aten" if layout.target == Target.ATEN else "native",
-        )

-        # ATEN KleidiAI kernel
-        # TODO: long term, we want to disfavor this kernel and instead use KleidiAI kernels in torchao
-        # that are vailable via PackedLinearInt8DynamicActivationIntxWeightLayout(target="native")
-        # where applicable
-        if layout.target == Target.ATEN:
+        tensor_impl_ctr_kwargs = {"bias": bias}
+
+        if layout.target == Target.NATIVE:
+            # Check kernels are installed/loaded
+            try:
+                torch.ops.torchao._pack_8bit_act_4bit_weight
+            except AttributeError:
+                raise Exception(
+                    "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+                    + " You can also set target to 'aten' if you are using ARM CPU."
+                )
+        elif layout.target == Target.ATEN:
+            # TODO: long term, we want to disfavor this route for using KleidiAI in torchao.
+            # KleidiAI kernels are accessible via Target.NATIVE if torchao is built
+            # with TORCHAO_BUILD_KLEIDIAI=1. The Target.NATIVE route has the advantage
+            # of automatically dispatching to different kernel libraries based on the CPU
+            # capability and the desired quantization.
            assert (
                TORCH_VERSION_AT_LEAST_2_6
            ), "ATEN target requires torch version > 2.6.0"
@@ -657,7 +651,7 @@ def is_torchao_op_skippable(layout):
        else ZeroPointDomain.NONE,
        _layout=layout,
        use_hqq=False,
-        tensor_impl_ctr_kwargs={"bias": bias} if has_bias else None,
+        tensor_impl_ctr_kwargs=tensor_impl_ctr_kwargs,
    )

    # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
@@ -678,7 +672,10 @@ def is_torchao_op_skippable(layout):
    module.weight = torch.nn.Parameter(weight, requires_grad=False)

    # If bias was packed with weights, set bias to None on module
-    if has_bias:
+    if (
+        isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
+        and layout.has_bias
+    ):
        module.bias = None

    return module
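
As a companion to the kernel guard added in the transform above, the following standalone probe is one way to check up front whether the experimental CPU kernels are loaded before choosing a target. The op name torch.ops.torchao._pack_8bit_act_4bit_weight comes from this diff; the helper function and its messages are illustrative only.

# Standalone availability probe mirroring the guard in the transform above.
import torch
import torchao  # importing torchao is assumed to register any compiled kernels


def native_kernels_available() -> bool:
    """Return True if the experimental CPU kernels used by
    PackedLinearInt8DynamicActivationIntxWeightLayout(target="native") are loaded."""
    try:
        torch.ops.torchao._pack_8bit_act_4bit_weight
        return True
    except AttributeError:
        return False


if __name__ == "__main__":
    if native_kernels_available():
        print("torchao experimental kernels are loaded; target='native' should work")
    else:
        print("kernels missing; build with `USE_CPP=1 pip install .` or use target='aten' on an ARM CPU")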