
Commit 500a456

committed
chore: move imports to top of the file
1 parent 6028093 commit 500a456

1 file changed: +4 −20 lines changed

torchao/quantization/quant_api.py

Lines changed: 4 additions & 20 deletions
@@ -21,10 +21,13 @@
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

+from torchao.dtypes.uintx.Uintx import UintxLayoutType
 from torchao.dtypes import (
     to_affine_quantized,
     TensorCoreTiledLayoutType,
-    PlainLayoutType
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
@@ -186,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(


 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -332,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )

 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -343,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -418,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -431,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)

 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -452,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
     if in_features <= 16:
         return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     def get_weight_block_size(x):
@@ -478,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


@@ -494,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):

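For readers unfamiliar with the pattern being cleaned up here: the deleted function-local imports (marked "# avoid circular dep" above) were a workaround for circular imports, and this commit hoists them to module scope. Below is a minimal sketch of the two styles using a hypothetical two-module package; pkg/dtypes.py and pkg/quant_api.py are illustrative names, not torchao's actual structure.

# pkg/dtypes.py (hypothetical)
def to_affine_quantized(x):
    # stand-in for the real conversion; just returns its input
    return x

# pkg/quant_api.py -- BEFORE: the import is deferred into the function body,
# so importing quant_api does not trigger an import of pkg.dtypes at load time.
def _int8_quant_deferred(x):
    from pkg.dtypes import to_affine_quantized  # "avoid circular dep" workaround
    return to_affine_quantized(x)

# pkg/quant_api.py -- AFTER: the import sits at the top of the file. This only
# works if pkg.dtypes no longer imports quant_api during its own initialization,
# i.e. the cycle itself has been broken.
from pkg.dtypes import to_affine_quantized

def _int8_quant(x):
    return to_affine_quantized(x)

The deferred form pays a small (cached) sys.modules lookup on every call and hides the module's dependencies inside function bodies; hoisting the imports makes the dependency graph explicit at the top of the file, which appears to be the motivation for this chore commit.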