 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional
 
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
 from torchao.dtypes import (
     to_affine_quantized,
     TensorCoreTiledLayoutType,
-    PlainLayoutType
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
@@ -186,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(
 
 
 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -332,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -343,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight
 
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -418,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -431,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)
 
 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -452,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
     if in_features <= 16:
         return weight
 
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     def get_weight_block_size(x):
@@ -478,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
 
 
@@ -494,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):
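
The configurators whose function-local imports are consolidated above are meant to be applied through torchao's top-level quantization entry point. As a point of reference only (not part of this diff), here is a minimal usage sketch; it assumes `quantize_` and `int8_weight_only` are importable from `torchao.quantization`, as in recent torchao releases:

```python
# Minimal sketch, not part of this diff: assumes torchao exposes quantize_ and
# int8_weight_only from torchao.quantization.
import torch
from torch import nn
from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))

# Replace each linear weight with an int8 weight-only quantized tensor subclass.
quantize_(model, int8_weight_only())

# Inference then runs against the quantized weights.
out = model(torch.randn(8, 64))
```

The other configurators touched here, such as `int8_dynamic_activation_int8_semi_sparse_weight()` and `uintx_weight_only(bit_width)`, follow the same pattern of being passed to the quantization entry point.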