@@ -10,16 +10,19 @@
 
 import torch
 import torch.nn as nn
-
-from executorch.exir import pass_base
 from torch.ao.quantization.fx._decomposed import (
     dequantize_per_channel_group,
     quantize_per_channel_group,
 )
 
 from torchao.dtypes import PlainLayout
-from torchao.quantization.granularity import PerGroup, PerRow
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -497,10 +500,10 @@ def quantize(self, model: nn.Module) -> nn.Module:
             to_linear_activation_quantized,
         )
         from torchao.quantization.quant_api import (
-            _get_linear_subclass_inserter,
             MappingType,
-            to_affine_quantized_intx,
             ZeroPointDomain,
+            _get_linear_subclass_inserter,
+            to_affine_quantized_intx,
         )
         from torchao.quantization.utils import _get_per_token_block_size
 
@@ -511,7 +514,6 @@ def int8_dynamic_activation_intx_weight(
     has_weight_zeros: bool = False,
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
-    scale_dtype=torch.float32,
     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
         target="native"
     ),  # PlainLayout() also works, but will be slow
@@ -560,7 +562,7 @@ def is_torchao_op_skippable(layout):
         torch.int3: 3,
         torch.int4: 4,
         torch.int5: 5,
         torch.int6: 6,
         torch.int7: 7,
         torch.int8: 8,
     }
@@ -588,23 +590,22 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
         assert weight.shape[-1] % group_size == 0
 
         layout = layout_arg
+        scale_dtype = None
         tensor_quantizer = to_affine_quantized_intx
         quant_min = -(1 << (bit_width - 1))
         quant_max = (1 << (bit_width - 1)) - 1
 
         if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
-            assert weight.device == torch.device(
-                "cpu"
+            assert (
+                weight.device == torch.device("cpu")
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
             assert (
                 weight.dtype == torch.float32
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.dtype=float32"
             assert (
                 act_mapping_type == MappingType.ASYMMETRIC
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-            assert (
-                not layout.has_params_set()
-            ), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
+            assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
             layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
                 bit_width=bit_width,
                 group_size=group_size,
@@ -616,7 +617,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                     weight_dtype != torch.int4
                     or has_weight_zeros != True
                     or weight_mapping_type == MappingType.ASYMMETRIC
-                    or scale_dtype != torch.bfloat16
                 ):
                     raise NotImplementedError(
                         "target 'aten' requires:\n"
@@ -628,6 +628,11 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 assert (
                     TORCH_VERSION_AT_LEAST_2_6
                 ), "aten target is requires torch version > 2.6.0"
+                if torch.backends.kleidiai.is_available():
+                    if isinstance(granularity, PerGroup):
+                        scale_dtype = (
+                            torch.bfloat16
+                        )  # KleidiAI kernel requires bfloat16 scale_dtype
             tensor_quantizer = (
                 to_packedlinearint8dynamicactivationintxweight_quantized_intx
             )
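
Taken together, the last few hunks remove scale_dtype from the public signature of int8_dynamic_activation_intx_weight: it now defaults to None inside apply and is only overridden for the 'aten' target when the KleidiAI backend is available and quantization is per-group. The snippet below is a minimal sketch of that selection rule in isolation, assuming the same PerGroup/PerRow granularity classes and torch.backends.kleidiai check that appear in the diff; the helper name choose_scale_dtype is hypothetical and not part of this patch.

import torch
from torchao.quantization.granularity import PerGroup, PerRow


def choose_scale_dtype(granularity, target: str = "native"):
    # Hypothetical helper, not part of the patch: it restates the scale_dtype
    # rule the diff adds inside apply(). By default the scale dtype is left
    # unset (None) and the quantizer picks its own default.
    scale_dtype = None
    # The KleidiAI per-group kernel expects bfloat16 scales, so the override
    # applies only for the 'aten' target, when KleidiAI is available, and for
    # per-group granularity (torch.backends.kleidiai requires torch >= 2.6).
    if (
        target == "aten"
        and torch.backends.kleidiai.is_available()
        and isinstance(granularity, PerGroup)
    ):
        scale_dtype = torch.bfloat16
    return scale_dtype


# Per-row granularity keeps the default; per-group may switch to bfloat16
# on builds where the KleidiAI kernels are available.
print(choose_scale_dtype(PerRow(), target="aten"))
print(choose_scale_dtype(PerGroup(32), target="aten"))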