
Commit 8726840

migrates prototype/mixed_precision to configs

Summary:

Note: had to remove int4/int8 functionality to simplify the refactor.
Whoever uses this script next and needs that functionality can add this back.

Test Plan:

```
pytest test/prototype/test_mixed_precision.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 2b0f46d
ghstack-comment-id: 2706815985
Pull Request resolved: #1854

1 parent 6f2019c · commit 8726840
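For context on the configs pattern this commit adopts: instead of a factory function that returns a subclass inserter, the public entry point becomes a dataclass config, and a transform registered for that config type does the module surgery. The sketch below is a simplified, hypothetical mock of the dispatch mechanism — the registry and helper names are invented for illustration; the real machinery lives in torchao.quantization.transform_module:

```python
from dataclasses import dataclass

import torch

# Hypothetical stand-in for the registry that
# register_quantize_module_handler maintains inside torchao.
_CONFIG_TO_HANDLER: dict = {}


def register_handler(config_cls):
    """Map a config class to the module transform that applies it."""

    def decorator(fn):
        _CONFIG_TO_HANDLER[config_cls] = fn
        return fn

    return decorator


@dataclass
class ExampleConfig:
    n: int = 8


@register_handler(ExampleConfig)
def _example_transform(
    module: torch.nn.Module, config: ExampleConfig
) -> torch.nn.Module:
    # A real handler would swap module.weight for a quantized tensor here.
    return module


def quantize_(model: torch.nn.Module, config) -> None:
    # Dispatch on type(config) and rewrite matching submodules in place,
    # mirroring how torchao's quantize_ consumes AOBaseConfig subclasses.
    handler = _CONFIG_TO_HANDLER[type(config)]
    for child in model.modules():
        if isinstance(child, torch.nn.Linear):
            handler(child, config)


quantize_(torch.nn.Sequential(torch.nn.Linear(4, 4)), ExampleConfig(n=6))
```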

File tree: 2 files changed (+44, -17 lines)


test/prototype/test_mixed_precision.py (1 addition, 1 deletion)

```diff
@@ -13,7 +13,7 @@
 class TestWeightOnlyQuantNaive(unittest.TestCase):
     def test_quantization_intNwo(self):
         # skip test int4wo for now since it is under development in torchao
-        for quantization_bit in [2, 3, 5, 6, 8]:
+        for quantization_bit in [2, 3, 5, 6]:
             for symmetric in [False, True]:
                 with self.subTest(
                     quantization_bit=quantization_bit, symmetric=symmetric
```

torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py (43 additions, 16 deletions)

```diff
@@ -1,15 +1,20 @@
+from dataclasses import dataclass
+
 import torch
 
-from torchao.quantization import int4_weight_only, int8_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 
-def intN_weight_only(group_size=32, n=8, symmetric=False):
+@dataclass
+class IntNWeightOnlyConfig(AOBaseConfig):
     """
-    Apply int N-bit weight only quantization to a linear layer.
+    Configuration for applying int N-bit weight only quantization to a linear layer.
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
@@ -18,6 +23,25 @@
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
     """
 
+    group_size: int = 32
+    n: int = 8
+    symmetric: bool = False
+
+
+# for bc
+intN_weight_only = IntNWeightOnlyConfig
+
+
+@register_quantize_module_handler(IntNWeightOnlyConfig)
+def _intN_weight_only_transform(
+    module: torch.nn.Module,
+    config: IntNWeightOnlyConfig,
+) -> torch.nn.Module:
+    group_size = config.group_size
+    n = config.n
+    symmetric = config.symmetric
+    weight = module.weight
+
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
@@ -64,16 +88,19 @@
             zero_point_dtype=zero_point_dtype,
         )
 
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
+    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+    if n == 8:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int8_weight_only again"
+        )
+    elif n == 4:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int4_weight_only again"
+        )
+    else:
+        if symmetric:
+            new_weight = apply_intN_weight_only_quant_sym(weight)
         else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception:
-        raise
+            new_weight = apply_intN_weight_only_quant_asym(weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    return module
```
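With the handler registered, callers build the config and hand it to quantize_, which dispatches to _intN_weight_only_transform by config type; old call sites survive via the intN_weight_only alias. A minimal usage sketch, assuming the import path implied by the file location above and an arbitrary toy model (note that n=8 and n=4 now raise until the int8/int4 paths are restored):

```python
import torch

from torchao.quantization import quantize_
from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import (
    IntNWeightOnlyConfig,
    intN_weight_only,  # backward-compat alias for IntNWeightOnlyConfig
)

# Toy model; sizes are arbitrary, but in_features should be divisible
# by group_size for grouped quantization.
model = torch.nn.Sequential(torch.nn.Linear(128, 64))

# New-style call: quantize_ dispatches on the config type and runs
# _intN_weight_only_transform, replacing each Linear's weight in place.
quantize_(model, IntNWeightOnlyConfig(n=6, group_size=32, symmetric=True))

# Old-style call sites still work because the factory name is now
# the config class itself.
legacy = torch.nn.Sequential(torch.nn.Linear(128, 64))
quantize_(legacy, intN_weight_only(n=2, group_size=32))
```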
