Skip to content

Commit 10408cd

Browse files
committed
Specify Quant Type in AoT Compiler for better results
1 parent 2ec8870 commit 10408cd

File tree

4 files changed, with 50 additions (+50) and 25 deletions (-25).

.ci/scripts/gather_test_models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from typing import Any
1515

1616
from examples.models import MODEL_NAME_TO_MODEL
17-
from examples.xnnpack import MODEL_NAME_TO_OPTIONS
17+
from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
1818

1919
DEFAULT_RUNNERS = {
2020
"linux": "linux.2xlarge",
@@ -154,7 +154,7 @@ def export_models_for_ci() -> dict[str, dict]:
154154
if backend == "xnnpack":
155155
if name not in MODEL_NAME_TO_OPTIONS:
156156
continue
157-
if MODEL_NAME_TO_OPTIONS[name].quantization:
157+
if MODEL_NAME_TO_OPTIONS[name].quantization != QuantType.NONE:
158158
backend += "-quantization"
159159

160160
if MODEL_NAME_TO_OPTIONS[name].delegation:

examples/xnnpack/__init__.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,44 @@
77
# pyre-unsafe
88

99
from dataclasses import dataclass
10+
from enum import Enum
11+
12+
13+
class QuantType(Enum):
14+
NONE = 1
15+
# Used for Operations that don't have weights
16+
STATIC_PER_TENSOR = 2
17+
# Best used for CNN/RNN models with Conv layers
18+
STATIC_PER_CHANNEL = 3
19+
# Used for Linear Layers and Transformer Based Models
20+
DYNAMIC_PER_CHANNEL = 4
1021

1122

1223
@dataclass
1324
class XNNPACKOptions(object):
14-
quantization: bool
25+
quantization: QuantType
1526
delegation: bool
1627

1728

1829
MODEL_NAME_TO_OPTIONS = {
19-
"linear": XNNPACKOptions(True, True),
20-
"add": XNNPACKOptions(True, True),
21-
"add_mul": XNNPACKOptions(True, True),
22-
"dl3": XNNPACKOptions(True, True),
23-
"ic3": XNNPACKOptions(True, True),
24-
"ic4": XNNPACKOptions(True, True),
25-
"mv2": XNNPACKOptions(True, True),
26-
"mv3": XNNPACKOptions(True, True),
27-
"resnet18": XNNPACKOptions(True, True),
28-
"resnet50": XNNPACKOptions(True, True),
29-
"vit": XNNPACKOptions(True, True),
30-
"w2l": XNNPACKOptions(True, True),
31-
"edsr": XNNPACKOptions(True, True),
32-
"mobilebert": XNNPACKOptions(True, True),
33-
"llama2": XNNPACKOptions(False, True),
34-
"emformer_join": XNNPACKOptions(True, True),
35-
"emformer_predict": XNNPACKOptions(True, True),
36-
"emformer_transcribe": XNNPACKOptions(True, True),
30+
"linear": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
31+
"add": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
32+
"add_mul": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
33+
"dl3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
34+
"ic3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
35+
"ic4": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
36+
"mv2": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
37+
"mv3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
38+
"resnet18": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
39+
"resnet50": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
40+
"vit": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
41+
"w2l": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
42+
"edsr": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
43+
"mobilebert": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
44+
"llama2": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
45+
"emformer_join": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
46+
"emformer_predict": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
47+
"emformer_transcribe": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
3748
}
3849

3950

examples/xnnpack/aot_compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
args = parser.parse_args()
6868

69-
if not args.delegate:
69+
if not args.delegate and args.quantize:
7070
raise NotImplementedError(
7171
"T161880157: Quantization-only without delegation is not supported yet"
7272
)
@@ -79,6 +79,8 @@
7979
f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
8080
)
8181

82+
quant_type = MODEL_NAME_TO_OPTIONS[args.model_name].quantization
83+
8284
model, example_inputs, _, _ = EagerModelFactory.create_model(
8385
*MODEL_NAME_TO_MODEL[args.model_name]
8486
)
@@ -91,7 +93,7 @@
9193
if args.quantize:
9294
logging.info("Quantizing Model...")
9395
# TODO(T165162973): This pass shall eventually be folded into quantizer
94-
model = quantize(model, example_inputs)
96+
model = quantize(model, example_inputs, quant_type)
9597
ep = torch.export.export_for_training(model, example_inputs)
9698

9799
edge = to_edge_transform_and_lower(

examples/xnnpack/quantization/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,25 @@
1313

1414
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
1515

16+
from .. import QuantType
1617

17-
def quantize(model, example_inputs):
18+
19+
def quantize(
20+
model, example_inputs, quant_type: QuantType = QuantType.STATIC_PER_TENSOR
21+
):
1822
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
1923
logging.info(f"Original model: {model}")
2024
quantizer = XNNPACKQuantizer()
2125
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
22-
operator_config = get_symmetric_quantization_config(is_per_channel=False)
26+
is_per_channel = (
27+
quant_type == QuantType.STATIC_PER_CHANNEL
28+
or quant_type == QuantType.DYNAMIC_PER_CHANNEL
29+
)
30+
is_dynamic = quant_type == QuantType.DYNAMIC_PER_CHANNEL
31+
operator_config = get_symmetric_quantization_config(
32+
is_per_channel=is_per_channel,
33+
is_dynamic=is_dynamic,
34+
)
2335
quantizer.set_global(operator_config)
2436
m = prepare_pt2e(model, quantizer)
2537
# calibration

0 commit comments

Comments
 (0)