Skip to content

Commit 913a9e9

Browse files
committed
migrate prototype/quantized_training to configs
Summary: As titled.

Test Plan:
```
pytest test/prototype/test_quantized_training.py -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: e4b9e86
ghstack-comment-id: 2706898680
Pull Request resolved: #1855
1 parent 64e5f8c commit 913a9e9

File tree

3 files changed

+66
-28
lines changed

3 files changed

+66
-28
lines changed

torchao/prototype/quantized_training/bitnet.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from torch.distributed._tensor import DTensor
1313
from torch.utils._triton import has_triton
1414

15-
from torchao.quantization.quant_api import _get_linear_subclass_inserter
15+
from torchao.core.config import AOBaseConfig
16+
from torchao.quantization.transform_module import (
17+
register_quantize_module_handler,
18+
)
1619
from torchao.utils import TorchAOBaseTensor
1720

1821
from .int8 import quantize_int8_rowwise
@@ -232,10 +235,22 @@ def backward(ctx, grad_output):
232235
return grad_input, grad_weight, grad_bias
233236

234237

235-
def bitnet_training():
236-
return _get_linear_subclass_inserter(
237-
BitNetTrainingLinearWeight, allow_requires_grad=True
238-
)
238+
class BitNetTrainingConfig(AOBaseConfig):
239+
pass
240+
241+
242+
# for bc
243+
bitnet_training = BitNetTrainingConfig
244+
245+
246+
@register_quantize_module_handler(BitNetTrainingConfig)
247+
def _bitnet_training_transform(
248+
module: torch.nn.Module,
249+
config: BitNetTrainingConfig,
250+
) -> torch.nn.Module:
251+
new_weight = BitNetTrainingLinearWeight(module.weight)
252+
module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
253+
return module
239254

240255

241256
def _pack_i2_in_i8(x: Tensor):

torchao/prototype/quantized_training/int8.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from torch import Tensor
55
from torch.utils._python_dispatch import return_and_correct_aliasing
66

7-
from torchao.quantization.quant_api import _get_linear_subclass_inserter
7+
from torchao.core.config import AOBaseConfig
8+
from torchao.quantization.transform_module import (
9+
register_quantize_module_handler,
10+
)
811
from torchao.utils import TorchAOBaseTensor
912

1013
aten = torch.ops.aten
@@ -293,7 +296,19 @@ def _(func, types, args, kwargs):
293296
return return_and_correct_aliasing(func, args, kwargs, out)
294297

295298

296-
def int8_weight_only_quantized_training():
297-
return _get_linear_subclass_inserter(
298-
Int8QuantizedTrainingLinearWeight.from_float, allow_requires_grad=True
299-
)
299+
class Int8WeightOnlyQuantizedTrainingConfig(AOBaseConfig):
300+
pass
301+
302+
303+
# for bc
304+
int8_weight_only_quantized_training = Int8WeightOnlyQuantizedTrainingConfig
305+
306+
307+
@register_quantize_module_handler(Int8WeightOnlyQuantizedTrainingConfig)
308+
def _int8_weight_only_quantized_training_transform(
309+
module: torch.nn.Module,
310+
config: Int8WeightOnlyQuantizedTrainingConfig,
311+
) -> torch.nn.Module:
312+
new_weight = Int8QuantizedTrainingLinearWeight.from_float(module.weight)
313+
module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
314+
return module

torchao/prototype/quantized_training/int8_mixed_precision.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from typing import Any, NamedTuple, Optional, Tuple, Union
1+
from dataclasses import dataclass
2+
from typing import Any, Optional, Tuple, Union
23

34
import torch
45
import torch.utils._pytree as pytree
56
from torch import Tensor, nn
67
from torch.utils._triton import has_triton
78

8-
from torchao.quantization.quant_api import _get_linear_subclass_inserter
9+
from torchao.core.config import AOBaseConfig
10+
from torchao.quantization.transform_module import (
11+
register_quantize_module_handler,
12+
)
913
from torchao.utils import TorchAOBaseTensor
1014

1115
from .int8 import quantize_int8_rowwise
@@ -23,10 +27,16 @@ def scaled_int8_mm(
2327
return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1)
2428

2529

26-
class Int8MixedPrecisionTrainingConfig(NamedTuple):
30+
@dataclass
31+
class Int8MixedPrecisionTrainingConfig(AOBaseConfig):
2732
output: bool = True
2833
grad_input: bool = True
2934
grad_weight: bool = True
35+
module_swap: bool = False
36+
37+
38+
# for bc
39+
int8_mixed_precision_training = Int8MixedPrecisionTrainingConfig
3040

3141

3242
_DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig()
@@ -265,25 +275,23 @@ def backward(ctx, grad_output):
265275
return grad_input, grad_weight, grad_bias, None
266276

267277

268-
def int8_mixed_precision_training(
269-
config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG,
270-
*,
271-
module_swap: bool = False,
278+
@register_quantize_module_handler(Int8MixedPrecisionTrainingConfig)
279+
def _int8_mixed_precision_training_transform(
280+
module: torch.nn.Module,
281+
config: Int8MixedPrecisionTrainingConfig,
272282
):
283+
module_swap = config.module_swap
284+
273285
# TODO: skip small layers that don't have perf gain.
274286
if module_swap:
275287
# module swap implementation
276-
def convert_linear(linear: nn.Linear):
277-
linear.__class__ = Int8MixedPrecisionTrainingLinear
278-
linear.config = config
279-
return linear
280-
281-
return convert_linear
288+
module.__class__ = Int8MixedPrecisionTrainingLinear
289+
module.config = config
290+
return module
282291

283292
else:
284293
# tensor subclass implementation
285-
return _get_linear_subclass_inserter(
286-
Int8MixedPrecisionTrainingLinearWeight,
287-
config=config,
288-
allow_requires_grad=True,
289-
)
294+
295+
new_weight = Int8MixedPrecisionTrainingLinearWeight(module.weight, config)
296+
module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
297+
return module

0 commit comments

Comments (0)