File tree 4 files changed +25
-10
lines changed 4 files changed +25
-10
lines changed Original file line number Diff line number Diff line change @@ -601,5 +601,18 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
601
601
# make sure it compiles
602
602
torch ._export .aot_compile (m_unwrapped , example_inputs )
603
603
604
+ def test_register_apply_tensor_subclass (self ):
605
+ from torchao .quantization import register_apply_tensor_subclass
606
+ def apply_my_dtype (weight ):
607
+ return weight * 2
608
+
609
+ m = ToyLinearModel ().eval ().to (torch .bfloat16 ).to ("cuda" )
610
+ with self .assertRaisesWithRegex ("not supported" ):
611
+ quantize (m , "my_dtype" )
612
+
613
+ register_apply_tensor_subclass ("my_dtype" , apply_my_dtype )
614
+ # make sure it runs
615
+ quantize (m , "my_dtype" )
616
+
604
617
if __name__ == "__main__" :
605
618
unittest .main ()
Original file line number Diff line number Diff line change 25
25
26
26
from torchao .quantization import (
27
27
autoquant ,
28
+ quantize ,
29
+ register_apply_tensor_subclass ,
28
30
)
29
31
from . import dtypes
30
32
31
33
__all__ = [
32
34
"dtypes" ,
33
35
"autoquant" ,
36
+ "quantize" ,
37
+ "register_apply_tensor_subclass" ,
34
38
]
Original file line number Diff line number Diff line change 14
14
from .autoquant import *
15
15
16
16
__all__ = [
17
- "DynamicallyPerAxisQuantizedLinear" ,
18
- "apply_weight_only_int8_quant" ,
19
- "apply_dynamic_quant" ,
20
- "change_linear_weights_to_int8_dqtensors" ,
21
- "change_linear_weights_to_int8_woqtensors" ,
22
- "change_linear_weights_to_int4_woqtensors" ,
23
17
"swap_conv2d_1x1_to_linear"
24
18
"safe_int_mm" ,
25
19
"autoquant" ,
31
25
"swap_linear_with_smooth_fq_linear" ,
32
26
"smooth_fq_linear_to_inference" ,
33
27
"set_smooth_fq_attribute" ,
34
- "Int8DynamicallyQuantizedLinearWeight" ,
35
- "Int8WeightOnlyQuantizedLinearWeight" ,
36
- "Int4WeightOnlyQuantizedLinearWeight" ,
37
28
"compute_error" ,
38
- "WeightOnlyInt8QuantLinear" ,
39
29
"Int4WeightOnlyGPTQQuantizer" ,
40
30
"Int4WeightOnlyQuantizer" ,
41
31
"quantize_affine" ,
42
32
"dequantize_affine" ,
43
33
"choose_qprams_affine" ,  # NOTE(review): probable typo for "choose_qparams_affine" — verify against the definition before renaming this export
34
+ "quantize" ,
35
+ "register_apply_tensor_subclass" ,
44
36
]
Original file line number Diff line number Diff line change 41
41
Int4WeightOnlyGPTQQuantizer ,
42
42
Int4WeightOnlyQuantizer ,
43
43
)
44
+ import logging
44
45
from .autoquant import autoquant , AutoQuantizableLinearWeight
45
46
46
47
@@ -438,3 +439,8 @@ def get_per_token_block_size(x):
438
439
"int8_weight_only" : int8wo (),
439
440
"int8_dynamic" : int8da_int8w (),
440
441
}
442
+
443
def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable) -> None:
    """Register *apply_tensor_subclass* under the shortcut string *name*.

    After registration, ``quantize(model, name)`` dispatches to the given
    callable. Registering an existing name overwrites the previous entry
    (a warning is logged).

    Args:
        name: shortcut string used to look up the transform.
        apply_tensor_subclass: callable applied to a weight tensor to
            produce the tensor-subclass-wrapped replacement.
    """
    if name in _APPLY_TS_TABLE:
        # Lazy %-style args so the message is only formatted when the
        # warning is actually emitted (stdlib logging convention).
        logging.warning("shortcut string %s already exists, overwriting", name)
    _APPLY_TS_TABLE[name] = apply_tensor_subclass
You can’t perform that action at this time.
0 commit comments