Skip to content

Commit 000a0fd

Browse files
authored
Move some util functions from quantization.utils to torchao.utils (#337)
Summary:
Moved
```
TORCH_VERSION_AFTER_2_(2/3/4)
get_model_size_in_bytes
unwrap_tensor_subclass
find_multiple
```
from torchao/quantization/utils.py to torchao/utils.py, and updated all importers to use `torchao.utils`.

Test Plan:
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 335171f commit 000a0fd

27 files changed

+129
-125
lines changed

test/dtypes/test_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
parametrize,
88
run_tests,
99
)
10-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
10+
from torchao.utils import TORCH_VERSION_AFTER_2_4
1111

1212
try:
1313
from torchao.prototype.fp8 import gemm_split_k, to_float8

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
from parameterized import parameterized
7171
import itertools
7272
import logging
73-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
73+
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
7474

7575
logger = logging.getLogger("INFO")
7676

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
)
4545

4646
from torchao.prototype.mx_formats.mx_tensor import MXTensor
47-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
47+
from torchao.utils import TORCH_VERSION_AFTER_2_4
4848

4949
if not TORCH_VERSION_AFTER_2_4:
5050
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
swap_linear_with_mx_linear,
2020
)
2121

22-
from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4
22+
from torchao.quantization.utils import compute_error
23+
from torchao.utils import TORCH_VERSION_AFTER_2_4
2324

2425
# trying to outsmart flake8
2526
__has_cuda = torch.cuda.is_available()

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
to_dtype,
2424
)
2525

26-
from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4
26+
from torchao.quantization.utils import compute_error
27+
from torchao.utils import TORCH_VERSION_AFTER_2_4
2728

2829
# trying to outsmart flake8
2930
__has_cuda = torch.cuda.is_available()

test/prototype/test_bitpacking.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torchao.prototype.common.bitpacking import pack, unpack
33
import pytest
44
from torch.utils._triton import has_triton
5-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
5+
from torchao.utils import TORCH_VERSION_AFTER_2_4
66

77
if not TORCH_VERSION_AFTER_2_4:
88
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -20,15 +20,15 @@ def test_uint3_to_int16_col_wise_cpu():
2020
unpacked = unpack(packed, 3, False, device='cpu')
2121
unpadded = unpacked[:test_tensor.shape[0], ...]
2222
assert(unpadded.allclose(test_tensor))
23-
23+
2424
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
2525
def test_uint4_to_uint8():
2626
test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
2727
packed = pack(test_tensor, 8, 4)
2828
unpacked = unpack(packed, 4)
2929
unpadded = unpacked[:test_tensor.shape[0], ...]
3030
assert(unpadded.allclose(test_tensor))
31-
31+
3232
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
3333
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
3434
def test_uint4_to_uint8_compile():
@@ -40,7 +40,7 @@ def test_uint4_to_uint8_compile():
4040
unpacked = unpack_compiled(packed, 4)
4141
unpadded = unpacked[:test_tensor.shape[0], ...]
4242
assert(unpadded.allclose(test_tensor))
43-
43+
4444
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
4545
def test_uint3_to_int16():
4646
test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
@@ -67,4 +67,4 @@ def test_uint3_to_int16_col_wise():
6767
packed = pack(test_tensor,16, 3, False)
6868
unpacked = unpack(packed, 3, False)
6969
unpadded = unpacked[:test_tensor.shape[0], ...]
70-
assert(unpadded.allclose(test_tensor))
70+
assert(unpadded.allclose(test_tensor))

test/quantization/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.nn as nn
1111
from torch import Tensor
1212
from torch.nn import functional as F
13-
from torchao.quantization.utils import find_multiple
13+
from torchao.utils import find_multiple
1414

1515
def prepare_inputs_for_model(inps, max_new_tokens=1):
1616
# this is because input from lm-eval is 2d

test/quantization/test_qat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
fake_quantize_per_token,
2020
)
2121
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
22-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
22+
from torchao.utils import TORCH_VERSION_AFTER_2_4
2323

2424

2525
# TODO: put this in a common test utils file

test/quantization/test_quant_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
get_apply_int8wo_quant,
4545
get_apply_int8dyn_quant,
4646
)
47-
from torchao.quantization.utils import (
47+
from torchao.utils import (
4848
TORCH_VERSION_AFTER_2_3,
4949
TORCH_VERSION_AFTER_2_4,
5050
)
@@ -556,7 +556,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
556556
self.assertTrue(torch.equal(res, ref))
557557

558558
# workaround for export path
559-
from torchao.quantization.utils import unwrap_tensor_subclass
559+
from torchao.utils import unwrap_tensor_subclass
560560
m_unwrapped = unwrap_tensor_subclass(m)
561561

562562
m = torch.export.export(m_unwrapped, example_inputs).module()

test/quantization/test_quant_primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
MappingType,
2020
)
2121

22-
from torchao.quantization.utils import (
22+
from torchao.utils import (
2323
TORCH_VERSION_AFTER_2_3,
2424
TORCH_VERSION_AFTER_2_4,
2525
)

test/sparsity/test_fast_sparse_training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
swap_semi_sparse_linear_with_linear,
1313
SemiSparseLinear
1414
)
15-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
15+
from torchao.utils import TORCH_VERSION_AFTER_2_4
1616

1717
class TestModel(nn.Module):
1818
def __init__(self):
@@ -42,7 +42,7 @@ def test_runtime_weight_sparsification(self):
4242
if isinstance(mod, torch.nn.Linear):
4343
sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense()
4444
mod.weight = nn.Parameter(sparse)
45-
45+
4646
dense_result = model(input)
4747

4848
# map from fqn to replacement linear module

test/sparsity/test_sparse_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
_get_subclass_inserter,
1212
_is_linear,
1313
)
14-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
14+
from torchao.utils import TORCH_VERSION_AFTER_2_3
1515
from torch.testing._internal.common_utils import TestCase
1616

1717

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
33
from torch.testing._internal.optests import opcheck
44
import torchao
5-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
5+
from torchao.utils import TORCH_VERSION_AFTER_2_4
66
import unittest
77
from parameterized import parameterized
88
import pytest

torchao/_executorch_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs):
99
torch.ops.quantized_decomposed.quantize_per_channel_group is only available
1010
in PyTorch 2.3+ and recently changed signatures.
1111
"""
12-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
12+
from torchao.utils import TORCH_VERSION_AFTER_2_3
1313
if TORCH_VERSION_AFTER_2_3:
1414
return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs)
1515
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.")
@@ -23,7 +23,7 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k
2323
torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available
2424
in PyTorch 2.3+ and recently changed signatures.
2525
"""
26-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
26+
from torchao.utils import TORCH_VERSION_AFTER_2_3
2727
if TORCH_VERSION_AFTER_2_3:
2828
return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs)
2929
raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.")
@@ -37,7 +37,7 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs):
3737
torch.ops.quantized_decomposed.dequantize_per_channel_group is only available
3838
in PyTorch 2.3+ and recently changed signatures.
3939
"""
40-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
40+
from torchao.utils import TORCH_VERSION_AFTER_2_3
4141
if TORCH_VERSION_AFTER_2_3:
4242
return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs)
4343
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.")
@@ -51,7 +51,7 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs):
5151
torch.ops.quantized_decomposed.quantize_per_token is only available
5252
in PyTorch 2.3+ and recently changed signatures.
5353
"""
54-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
54+
from torchao.utils import TORCH_VERSION_AFTER_2_3
5555
if TORCH_VERSION_AFTER_2_3:
5656
return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs)
5757
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.")
@@ -65,7 +65,7 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs):
6565
torch.ops.quantized_decomposed.dequantize_per_token is only available
6666
in PyTorch 2.3+ and recently changed signatures.
6767
"""
68-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
68+
from torchao.utils import TORCH_VERSION_AFTER_2_3
6969
if TORCH_VERSION_AFTER_2_3:
7070
return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs)
7171
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.")

torchao/kernel/intmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import torch
44

5-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2
5+
from torchao.utils import TORCH_VERSION_AFTER_2_2
66

77
try:
88
# Only works for torch2.2 or newer.

torchao/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torch import Tensor
3-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
3+
from torchao.utils import TORCH_VERSION_AFTER_2_4
44

55

66
def register_custom_op(name):

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from torch.utils._triton import has_triton
1313

14-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
14+
from torchao.utils import TORCH_VERSION_AFTER_2_4
1515

1616
# TODO(future): if needed, make the below work on previous PyTorch versions,
1717
# just need to hunt down the previous location of `libdevice`. An assert

torchao/quantization/GPTQ.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
from .utils import (
2323
_lm_eval_available,
2424
_MultiInput,
25-
TORCH_VERSION_AFTER_2_3,
25+
)
26+
from torchao.utils import (
2627
find_multiple,
2728
)
29+
from torchao.utils import TORCH_VERSION_AFTER_2_3
2830
from typing import Any, Dict, Optional
2931
from .unified import Quantizer
3032

torchao/quantization/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
"Int8WeightOnlyQuantizedLinearWeight",
4545
"Int4WeightOnlyQuantizedLinearWeight",
4646
"compute_error",
47-
"get_model_size_in_bytes",
4847
"WeightOnlyInt8QuantLinear",
4948
"Int4WeightOnlyGPTQQuantizer",
5049
"Int4WeightOnlyQuantizer",

torchao/quantization/autoquant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
quantize_activation_per_token_absmax,
1010
safe_int_mm,
1111
)
12-
from .utils import TORCH_VERSION_AFTER_2_4
12+
from torchao.utils import TORCH_VERSION_AFTER_2_4
1313
import torch.nn.functional as F
1414
try:
1515
from torch._inductor.utils import do_bench

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing import Any, Callable
2626

2727
from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
28-
from .utils import (
28+
from torchao.utils import (
2929
TORCH_VERSION_AFTER_2_4,
3030
unwrap_tensor_subclass,
3131
)

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from torchao.kernel.intmm import int_scaled_matmul
1616
from torchao.kernel.intmm import safe_int_mm
17-
from .utils import TORCH_VERSION_AFTER_2_3
17+
from torchao.utils import TORCH_VERSION_AFTER_2_3
1818

1919

2020
__all__ = [

torchao/quantization/subclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
groupwise_affine_quantize_tensor_from_qparams,
1919
MappingType,
2020
)
21-
from .utils import find_multiple
21+
from torchao.utils import find_multiple
2222
from typing import Tuple, Optional, Callable, Dict, Any
2323

2424

0 commit comments

Comments
 (0)