
Commit 191c158

NXP backend: Add NeutronQuantizer
1 parent 4717459 commit 191c158

File tree

3 files changed: +669 -0 lines changed
+244
@@ -0,0 +1,244 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Quantizer for Neutron NPU.

from typing import List, Type

import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import (
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig

from executorch.backends.cadence.aot.quantizer.patterns import (
    QuantizationPattern,
    PartitionAnchors,
    AddmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceAtenQuantizer

# Quantization specification used by the Neutron NPU
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2 ** -12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# The bias quantization specification is as follows:
#   dtype = torch.int32
#   quant_min, quant_max - the full int32 range
#   qscheme = torch.per_channel_symmetric (for Conv), torch.per_tensor_symmetric (for Addmm), i.e. zero_point = 0
#   scale = input_scale * weight_scale
# It is set by the *PatternQuantizer directly.
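# For example (illustrative values only): with input_scale = 0.02 and weight_scale = 0.005,
# the derived bias scale is 0.02 * 0.005 = 1e-4, with zero_point = 0.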
bias_qspec = None


class SharedSpecPattern(QuantizationPattern):
    """
    Quantization pattern for shared quantization.

    The quantization is derived from the quantization of the previous node, and the input and output share
    the same quantization parameters (scale and zero-point).
    """

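    # With a shared spec, the output of an op such as max_pool2d or reshape reuses its input's
    # scale and zero-point, so the op can run on int8 data without a requantization step.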
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors | None:
        node = fused_partition[0].nodes[-1]
        assert len(fused_partition[0].input_nodes) == 1
        prev_node = fused_partition[0].input_nodes[0]

        # If a node with a shared quantization spec has no previous node, return None so the node is not quantized.
        if not hasattr(prev_node, "meta") or "quantization_annotation" not in prev_node.meta:
            return None
        else:
            qspec = SharedQuantizationSpec(prev_node)

        return PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[(node, qspec)],
        )

    def replacement_op(self):
        assert False


class MaxPoolPattern(SharedSpecPattern):
    """
    Quantizer for MaxPool2D operator.
    """

    def partition_types(self):
        return [torch.ops.aten.max_pool2d.default]


class AvgPoolPattern(SharedSpecPattern):
    """
    Quantizer for AvgPool2D operator.
    """

    def partition_types(self):
        return [torch.ops.aten.avg_pool2d.default]


class PadPattern(SharedSpecPattern):
    """
    Quantizer for Pad operator.
    """

    def partition_types(self):
        return [torch.ops.aten.pad.default]


class ReluPattern(SharedSpecPattern):
    """
    Quantizer for Relu operator. The shared quantization spec is selected, as ReLU usually follows a computation layer.
    """

    def partition_types(self):
        return [torch.ops.aten.relu.default]


class ReluInPlacePattern(SharedSpecPattern):
    """
    Quantizer for Relu operator with param inplace=True. The shared quantization spec is selected, as ReLU usually
    follows a computation layer.
    """

    def partition_types(self):
        return [torch.ops.aten.relu_.default]


class ReshapePattern(SharedSpecPattern):
    """
    Quantizer for Reshape operator.
    """

    def partition_types(self):
        return [torch.ops.aten.reshape.default]


class PermutePattern(SharedSpecPattern):
    """
    Quantizer for Permute operator.
    """

    def partition_types(self):
        return [torch.ops.aten.permute.default]


class SoftMaxPattern(QuantizationPattern):
    """
    Quantizer for Softmax operator.

    The quantization of the Softmax output is fixed to scale 1/256, zero point -128, dtype int8.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.softmax.int]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        node = fused_partition[0].nodes[-1]
        assert len(fused_partition[0].input_nodes) == 1

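        # Softmax outputs lie in [0, 1]; with scale = 1/256 and zero_point = -128 the int8
        # range [-128, 127] dequantizes to [0, 255/256], so these fixed parameters do not
        # depend on calibration data.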
        qspec = FixedQParamsQuantizationSpec(
            dtype=torch.int8,
            scale=1.0 / 256.0,
            zero_point=-128,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_affine,
        )

        return PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[(node, qspec)],
        )

    def replacement_op(self):
        assert False


class NeutronQuantizer(ComposableQuantizer):
    def __init__(self):
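        # Positional arguments of QuantizationConfig: input activation, output activation, weight, bias.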
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        static_fc_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_fc_qspec,
            None,
        )
        super().__init__(
            [
                CadenceAtenQuantizer(AddmmPattern(), static_fc_qconfig),
                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
                CadenceAtenQuantizer(LinearPattern(), static_fc_qconfig),
                CadenceAtenQuantizer(MaxPoolPattern(), static_qconfig),
                CadenceAtenQuantizer(SoftMaxPattern(), static_qconfig),
                CadenceAtenQuantizer(ReshapePattern(), static_qconfig),
                CadenceAtenQuantizer(PermutePattern(), static_qconfig),
                CadenceAtenQuantizer(PadPattern(), static_qconfig),
                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
                CadenceAtenQuantizer(ReluInPlacePattern(), static_qconfig),
                CadenceAtenQuantizer(AvgPoolPattern(), static_qconfig),
            ]
        )

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        return model

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return super().annotate(model)

    def validate(self, model: torch.fx.GraphModule) -> None:
        return super().validate(model)
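
Below is a minimal usage sketch of how a quantizer like this is typically driven through the PT2E quantization flow. The example model, the export call, and the import path of NeutronQuantizer are illustrative assumptions rather than part of this commit; only prepare_pt2e/convert_pt2e and the NeutronQuantizer class itself are taken as given.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical import path for the class added in this commit.
from neutron_quantizer import NeutronQuantizer


class SmallConvNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


model = SmallConvNet().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export to an ATen-dialect graph (the export API varies by PyTorch version).
exported = torch.export.export_for_training(model, example_inputs).module()

# Annotate with Neutron quantization specs, run calibration, then convert.
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)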
