We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent adc78b7 commit 6bb39a3Copy full SHA for 6bb39a3
torchao/quantization/pt2e/quantize_pt2e.py
@@ -25,6 +25,7 @@
25
_fuse_conv_bn_,
26
_get_node_name_to_scope,
27
)
28
+from typing import Union
29
30
from .convert import _convert_to_reference_decomposed_fx
31
from .prepare import prepare
@@ -39,7 +40,7 @@
39
40
41
def prepare_pt2e(
42
model: GraphModule,
- quantizer: Quantizer,
43
+ quantizer: Union[Quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer],
44
) -> GraphModule:
45
"""Prepare a model for post training quantization
46
@@ -127,7 +128,7 @@ def calibrate(model, data_loader):
127
128
129
def prepare_qat_pt2e(
130
131
+ quantizer: Quantizer | torch.ao.quantization.quantizer.quantizer.Quantizer,
132
133
"""Prepare a model for quantization aware training
134
0 commit comments