Skip to content

Commit 103bad4

Browse files
committed
Add register_apply_tensor_subclass
Summary: `register_apply_tensor_subclass` allows users to add a string shortcut for a new apply_tensor_subclass function, they can use this to test their new dtype tensor subclass see `test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass` for detail Test Plan: python test/quantization/test_quant_api.py -k test_register_apply_tensor_subclass Reviewers: Subscribers: Tasks: Tags:
1 parent ca19e23 commit 103bad4

File tree

4 files changed

+25
-10
lines changed

4 files changed

+25
-10
lines changed

test/quantization/test_quant_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,5 +601,18 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
601601
# make sure it compiles
602602
torch._export.aot_compile(m_unwrapped, example_inputs)
603603

604+
def test_register_apply_tensor_subclass(self):
605+
from torchao.quantization import register_apply_tensor_subclass
606+
def apply_my_dtype(weight):
607+
return weight * 2
608+
609+
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
610+
with self.assertRaisesWithRegex("not supported"):
611+
quantize(m, "my_dtype")
612+
613+
register_apply_tensor_subclass("my_dtype", apply_my_dtype)
614+
# make sure it runs
615+
quantize(m, "my_dtype")
616+
604617
if __name__ == "__main__":
605618
unittest.main()

torchao/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525

2626
from torchao.quantization import (
2727
autoquant,
28+
quantize,
29+
register_apply_tensor_subclass,
2830
)
2931
from . import dtypes
3032

3133
__all__ = [
3234
"dtypes",
3335
"autoquant",
36+
"quantize",
37+
"register_apply_tensor_subclass",
3438
]

torchao/quantization/__init__.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
from .autoquant import *
1515

1616
__all__ = [
17-
"DynamicallyPerAxisQuantizedLinear",
18-
"apply_weight_only_int8_quant",
19-
"apply_dynamic_quant",
20-
"change_linear_weights_to_int8_dqtensors",
21-
"change_linear_weights_to_int8_woqtensors",
22-
"change_linear_weights_to_int4_woqtensors",
2317
"swap_conv2d_1x1_to_linear"
2418
"safe_int_mm",
2519
"autoquant",
@@ -31,14 +25,12 @@
3125
"swap_linear_with_smooth_fq_linear",
3226
"smooth_fq_linear_to_inference",
3327
"set_smooth_fq_attribute",
34-
"Int8DynamicallyQuantizedLinearWeight",
35-
"Int8WeightOnlyQuantizedLinearWeight",
36-
"Int4WeightOnlyQuantizedLinearWeight",
3728
"compute_error",
38-
"WeightOnlyInt8QuantLinear",
3929
"Int4WeightOnlyGPTQQuantizer",
4030
"Int4WeightOnlyQuantizer",
4131
"quantize_affine",
4232
"dequantize_affine",
4333
"choose_qprams_affine",
34+
"quantize",
35+
"register_apply_tensor_subclass",
4436
]

torchao/quantization/quant_api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Int4WeightOnlyGPTQQuantizer,
4242
Int4WeightOnlyQuantizer,
4343
)
44+
import logging
4445
from .autoquant import autoquant, AutoQuantizableLinearWeight
4546

4647

@@ -438,3 +439,8 @@ def get_per_token_block_size(x):
438439
"int8_weight_only": int8wo(),
439440
"int8_dynamic": int8da_int8w(),
440441
}
442+
443+
def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
444+
if name in _APPLY_TS_TABLE:
445+
logging.warning(f"shortcut string {name} already exist, overwriting")
446+
_APPLY_TS_TABLE[name] = apply_tensor_subclass

0 commit comments

Comments
 (0)