
Commit db29fda (1 parent: 4717459)

NXP backend: Add NeutronQuantizer
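
This commit adds the NeutronQuantizer, a ComposableQuantizer for the NXP Neutron NPU that annotates static int8 quantization per supported operator pattern. It plugs into the standard PT2E quantization flow; a minimal usage sketch follows (the toy model, input shape, capture API, and module path are illustrative assumptions, not part of this commit):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Module path assumed from the imports in this diff.
    from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

    # Hypothetical toy model built from supported patterns (Conv2d + ReLU).
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)

    # Capture an ATen-level graph; the exact capture API depends on the torch version.
    exported = torch.export.export_for_training(model, example_inputs).module()

    prepared = prepare_pt2e(exported, NeutronQuantizer())  # inserts observers per the annotations
    prepared(*example_inputs)                              # calibrate on representative data
    quantized = convert_pt2e(prepared)                     # materialize int8 quant/dequant ops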

5 files changed: 1142 additions, 0 deletions
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer

from executorch.backends.nxp.quantizer.patterns import (
    AddmmPattern,
    AvgPoolPattern,
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
    MaxPoolPattern,
    PadPattern,
    PermutePattern,
    QuantizationPattern,
    ReluInPlacePattern,
    ReluPattern,
    ReshapePattern,
    SoftMaxPattern,
)
from executorch.backends.nxp.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)


class NeutronAtenQuantizer(Quantizer):
    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors or anchors.empty:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec]],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    arg = (
                        # pyre-ignore[16]: no attribute
                        node.args[idx]
                        if isinstance(idx, int)
                        # pyre-ignore[16]: no attribute
                        else node.args[idx[0]][idx[1]]
                    )
                    annotation.input_qspec_map[arg] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            def annotate_weights_or_biases(
                weights_or_biases: List[Tuple[fx.Node, int]],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in weights_or_biases:
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    node.meta["quantization_annotation"] = annotation

            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_weights_or_biases(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_weights_or_biases(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


# Quantization Specification used by Neutron NPU
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# The bias quantization specification is as follows:
#   dtype = torch.int32
#   quant_min, quant_max - full int32 range
#   qscheme = torch.per_channel_symmetric (for Conv), torch.per_tensor_symmetric (for Addmm) ==> i.e. zero_point = 0
#   scale = input_scale * weight_scale
# It is set by the *PatternQuantizer directly.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
    def __init__(self):
        # QuantizationConfig fields: input activation, output activation, weight, bias.
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        static_fc_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_fc_qspec,
            None,
        )
        super().__init__(
            [
                NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
                NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
                NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
                NeutronAtenQuantizer(PadPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
            ]
        )

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        return model
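
The bias handling described in the comment above (full int32 range, scale = input_scale * weight_scale, zero_point = 0) is the kind of specification usually expressed with a DerivedQuantizationSpec, whose qparams are computed from the input and weight observers. A minimal sketch, assuming hypothetical input/weight/op nodes supplied by a pattern (not code from this commit):

    import torch
    from torch.ao.quantization.quantizer import DerivedQuantizationSpec


    def _derive_bias_qparams(observers):
        # scale_bias = scale_input * scale_weight; symmetric, so zero_point = 0.
        input_obs, weight_obs = observers
        input_scale, _ = input_obs.calculate_qparams()
        weight_scale, _ = weight_obs.calculate_qparams()
        bias_scale = input_scale * weight_scale
        return bias_scale, torch.zeros_like(bias_scale, dtype=torch.int64)


    def make_bias_qspec(input_node, weight_node, op_node):
        # input_node/weight_node/op_node are hypothetical fx.Nodes from a matched pattern.
        return DerivedQuantizationSpec(
            derived_from=[(input_node, op_node), (weight_node, op_node)],
            derive_qparams_fn=_derive_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_symmetric,
        )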
