Skip to content

Commit 2b56245

Browse files
authored
Add register_apply_tensor_subclass (#366)
Summary: `register_apply_tensor_subclass` allows users to add a string shortcut for a new apply_tensor_subclass function. They can use this to test their new dtype tensor subclass; see `test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass` for details. Test Plan: python test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass Reviewers: Subscribers: Tasks: Tags:
1 parent ca19e23 commit 2b56245

File tree

4 files changed

+45
-14
lines changed

4 files changed

+45
-14
lines changed

test/quantization/test_quant_api.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
Int8WeightOnlyQuantizedLinearWeight,
3232
Int4WeightOnlyQuantizedLinearWeight,
3333
)
34+
from torchao import quantize
3435
from torchao.quantization.quant_api import (
3536
_replace_with_custom_fn_if_matches_filter,
3637
Quantizer,
3738
TwoStepQuantizer,
38-
quantize,
3939
int8da_int4w,
4040
int4wo,
4141
int8wo,
@@ -51,6 +51,7 @@
5151
from torchao.utils import unwrap_tensor_subclass
5252
import copy
5353
import tempfile
54+
from torch.testing._internal.common_utils import TestCase
5455

5556

5657
def dynamic_quant(model, example_inputs):
@@ -147,7 +148,7 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
147148
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
148149
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
149150

150-
class TestQuantFlow(unittest.TestCase):
151+
class TestQuantFlow(TestCase):
151152
def test_dynamic_quant_gpu_singleline(self):
152153
m = ToyLinearModel().eval()
153154
example_inputs = m.example_inputs()
@@ -601,5 +602,20 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
601602
# make sure it compiles
602603
torch._export.aot_compile(m_unwrapped, example_inputs)
603604

605+
def test_register_apply_tensor_subclass(self):
606+
from torchao import register_apply_tensor_subclass
607+
def apply_my_dtype(weight):
608+
return weight * 2
609+
610+
m = ToyLinearModel().eval()
611+
example_inputs = m.example_inputs()
612+
with self.assertRaisesRegex(ValueError, "not supported"):
613+
quantize(m, "my_dtype")
614+
615+
register_apply_tensor_subclass("my_dtype", apply_my_dtype)
616+
# make sure it runs
617+
quantize(m, "my_dtype")
618+
m(*example_inputs)
619+
604620
if __name__ == "__main__":
605621
unittest.main()

torchao/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525

2626
from torchao.quantization import (
2727
autoquant,
28+
quantize,
29+
register_apply_tensor_subclass,
2830
)
2931
from . import dtypes
3032

3133
__all__ = [
3234
"dtypes",
3335
"autoquant",
36+
"quantize",
37+
"register_apply_tensor_subclass",
3438
]

torchao/quantization/__init__.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
from .autoquant import *
1515

1616
__all__ = [
17-
"DynamicallyPerAxisQuantizedLinear",
18-
"apply_weight_only_int8_quant",
19-
"apply_dynamic_quant",
20-
"change_linear_weights_to_int8_dqtensors",
21-
"change_linear_weights_to_int8_woqtensors",
22-
"change_linear_weights_to_int4_woqtensors",
2317
"swap_conv2d_1x1_to_linear"
2418
"safe_int_mm",
2519
"autoquant",
@@ -31,14 +25,12 @@
3125
"swap_linear_with_smooth_fq_linear",
3226
"smooth_fq_linear_to_inference",
3327
"set_smooth_fq_attribute",
34-
"Int8DynamicallyQuantizedLinearWeight",
35-
"Int8WeightOnlyQuantizedLinearWeight",
36-
"Int4WeightOnlyQuantizedLinearWeight",
3728
"compute_error",
38-
"WeightOnlyInt8QuantLinear",
3929
"Int4WeightOnlyGPTQQuantizer",
4030
"Int4WeightOnlyQuantizer",
4131
"quantize_affine",
4232
"dequantize_affine",
4333
"choose_qprams_affine",
34+
"quantize",
35+
"register_apply_tensor_subclass",
4436
]

torchao/quantization/quant_api.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Int4WeightOnlyGPTQQuantizer,
4242
Int4WeightOnlyQuantizer,
4343
)
44+
import logging
4445
from .autoquant import autoquant, AutoQuantizableLinearWeight
4546

4647

@@ -50,13 +51,14 @@
5051
"TwoStepQuantizer",
5152
"Int4WeightOnlyGPTQQuantizer",
5253
"Int4WeightOnlyQuantizer",
53-
"quantize",
5454
"autoquant",
5555
"_get_subclass_inserter",
56+
"quantize",
5657
"int8da_int4w",
5758
"int8da_int8w",
5859
"int4wo",
5960
"int8wo",
61+
"register_apply_tensor_subclass",
6062
]
6163

6264
from .GPTQ import (
@@ -292,7 +294,8 @@ def filter_fn(module, fqn):
292294
m = quantize(m, apply_weight_quant, filter_fn)
293295
"""
294296
if isinstance(apply_tensor_subclass, str):
295-
assert apply_tensor_subclass in _APPLY_TS_TABLE, f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}"
297+
if apply_tensor_subclass not in _APPLY_TS_TABLE:
298+
raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
296299
apply_tensor_subclass = _APPLY_TS_TABLE[apply_tensor_subclass]
297300

298301
assert not isinstance(apply_tensor_subclass, str)
@@ -438,3 +441,19 @@ def get_per_token_block_size(x):
438441
"int8_weight_only": int8wo(),
439442
"int8_dynamic": int8da_int8w(),
440443
}
444+
445+
def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
446+
"""Register a string shortcut for `apply_tensor_subclass` that takes a weight Tensor
447+
as input and outputs a tensor with tensor subclass applied
448+
449+
Example:
450+
def apply_my_dtype(weight):
451+
return weight * 2
452+
453+
register_apply_tensor_subclass("my_dtype", apply_my_dtype)
454+
# calls `apply_my_dtype` on weights
455+
quantize(m, "my_dtype")
456+
"""
457+
if name in _APPLY_TS_TABLE:
458+
logging.warning(f"shortcut string {name} already exist, overwriting")
459+
_APPLY_TS_TABLE[name] = apply_tensor_subclass

0 commit comments

Comments
 (0)