
Commit 0d33ec9

up

1 parent 8856c1e

1 file changed

torchao/experimental/quant_api.py

Lines changed: 19 additions & 14 deletions
@@ -10,16 +10,19 @@

 import torch
 import torch.nn as nn
-
-from executorch.exir import pass_base
 from torch.ao.quantization.fx._decomposed import (
     dequantize_per_channel_group,
     quantize_per_channel_group,
 )

 from torchao.dtypes import PlainLayout
-from torchao.quantization.granularity import PerGroup, PerRow
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -497,10 +500,10 @@ def quantize(self, model: nn.Module) -> nn.Module:
     to_linear_activation_quantized,
 )
 from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
     MappingType,
-    to_affine_quantized_intx,
     ZeroPointDomain,
+    _get_linear_subclass_inserter,
+    to_affine_quantized_intx,
 )
 from torchao.quantization.utils import _get_per_token_block_size

@@ -511,7 +514,6 @@ def int8_dynamic_activation_intx_weight(
     has_weight_zeros: bool = False,
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
-    scale_dtype=torch.float32,
     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
         target="native"
     ),  # PlainLayout() also works, but will be slow
@@ -560,7 +562,7 @@ def is_torchao_op_skippable(layout):
         torch.int3: 3,
         torch.int4: 4,
         torch.int5: 5,
-        torch.int6: 4,
+        torch.int6: 6,
         torch.int7: 7,
         torch.int8: 8,
     }
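
The table above must map each sub-byte dtype to its own width (hence `torch.int6` mapping to 6, not 4), because `apply` below derives the signed quantization range directly from `bit_width`. A minimal sketch of that relationship, assuming torch >= 2.6 for the sub-byte dtypes; the binding `dtype_to_bit_width` and the helper `signed_range` are illustrative names, not taken from this diff:

import torch

# Illustrative name; this diff shows only the dict entries, not its binding.
dtype_to_bit_width = {
    torch.int3: 3,
    torch.int4: 4,
    torch.int5: 5,
    torch.int6: 6,
    torch.int7: 7,
    torch.int8: 8,
}


def signed_range(bit_width):
    # Same arithmetic as the quant_min/quant_max lines later in this diff:
    # an n-bit signed (two's-complement) type spans [-2**(n-1), 2**(n-1) - 1].
    return -(1 << (bit_width - 1)), (1 << (bit_width - 1)) - 1


assert signed_range(dtype_to_bit_width[torch.int6]) == (-32, 31)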
@@ -588,23 +590,22 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
         assert weight.shape[-1] % group_size == 0

         layout = layout_arg
+        scale_dtype = None
         tensor_quantizer = to_affine_quantized_intx
         quant_min = -(1 << (bit_width - 1))
         quant_max = (1 << (bit_width - 1)) - 1

         if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
-            assert weight.device == torch.device(
-                "cpu"
+            assert (
+                weight.device == torch.device("cpu")
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
             assert (
                 weight.dtype == torch.float32
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.dtype=float32"
             assert (
                 act_mapping_type == MappingType.ASYMMETRIC
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-            assert (
-                not layout.has_params_set()
-            ), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
+            assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
             layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
                 bit_width=bit_width,
                 group_size=group_size,
@@ -616,7 +617,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                     weight_dtype != torch.int4
                     or has_weight_zeros != True
                     or weight_mapping_type == MappingType.ASYMMETRIC
-                    or scale_dtype != torch.bfloat16
                 ):
                     raise NotImplementedError(
                         "target 'aten' requires:\n"
@@ -628,6 +628,11 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 assert (
                     TORCH_VERSION_AT_LEAST_2_6
                 ), "aten target is requires torch version > 2.6.0"
+                if torch.backends.kleidiai.is_available():
+                    if isinstance(granularity, PerGroup):
+                        scale_dtype = (
+                            torch.bfloat16
+                        )  # KleidiAI kernel requires bfloat16 scale_dtype
                 tensor_quantizer = (
                     to_packedlinearint8dynamicactivationintxweight_quantized_intx
                 )
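
Taken together, the commit removes scale_dtype from the public signature of int8_dynamic_activation_intx_weight and instead initializes it to None inside apply, switching it to torch.bfloat16 only on the aten path when the KleidiAI backend is available and the granularity is PerGroup. A minimal caller-side sketch under stated assumptions: quantize_ is torchao's standard entry point, and the model, shapes, and keyword usage here are illustrative rather than taken from this commit:

import torch
import torch.nn as nn

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# Toy model: the default packed layout asserts weight.dtype=float32 and
# weight.device=CPU, per the checks in the diff above.
model = nn.Sequential(nn.Linear(256, 128)).to(torch.float32)

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),  # group size must divide weight.shape[-1]
        has_weight_zeros=False,
        # Note: after this commit there is no scale_dtype argument; the
        # quantizer selects bfloat16 scales itself when KleidiAI applies.
    ),
)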
