Skip to content

Commit 6bb39a3

Browse files
metascroyfacebook-github-bot
authored andcommitted
Add backward compatible types to pt2e prepare (#2244)
Summary: Pull Request resolved: #2244 Differential Revision: D75248288
1 parent adc78b7 commit 6bb39a3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchao/quantization/pt2e/quantize_pt2e.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_fuse_conv_bn_,
2626
_get_node_name_to_scope,
2727
)
28+
from typing import Union
2829

2930
from .convert import _convert_to_reference_decomposed_fx
3031
from .prepare import prepare
@@ -39,7 +40,7 @@
3940

4041
def prepare_pt2e(
4142
model: GraphModule,
42-
quantizer: Quantizer,
43+
quantizer: Union[Quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer],
4344
) -> GraphModule:
4445
"""Prepare a model for post training quantization
4546
@@ -127,7 +128,7 @@ def calibrate(model, data_loader):
127128

128129
def prepare_qat_pt2e(
129130
model: GraphModule,
130-
quantizer: Quantizer,
131+
quantizer: Quantizer | torch.ao.quantization.quantizer.quantizer.Quantizer,
131132
) -> GraphModule:
132133
"""Prepare a model for quantization aware training
133134

0 commit comments

Comments
 (0)