|
| 1 | +# Copyright 2024 NXP |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# Quantizer for Neutron NPU. |
| 7 | + |
| 8 | +from typing import List, Type |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import fx |
| 12 | +from torch._ops import OpOverload |
| 13 | +from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver |
| 14 | +from torch.ao.quantization.quantizer import ( |
| 15 | + FixedQParamsQuantizationSpec, |
| 16 | + SharedQuantizationSpec, |
| 17 | +) |
| 18 | +from torch.ao.quantization.quantizer import QuantizationSpec |
| 19 | +from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer |
| 20 | +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig |
| 21 | + |
| 22 | +from executorch.backends.cadence.aot.quantizer.patterns import ( |
| 23 | + QuantizationPattern, |
| 24 | + PartitionAnchors, |
| 25 | + AddmmPattern, |
| 26 | + Conv1dPattern, |
| 27 | + Conv2dPattern, |
| 28 | + LinearPattern, |
| 29 | +) |
| 30 | +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceAtenQuantizer |
| 31 | + |
| 32 | +# Quantization Specification used by Neutron NPU |
| 33 | +act_qspec = QuantizationSpec( |
| 34 | + dtype=torch.int8, |
| 35 | + quant_min=-128, |
| 36 | + quant_max=127, |
| 37 | + qscheme=torch.per_tensor_affine, |
| 38 | + is_dynamic=False, |
| 39 | + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2 ** -12), |
| 40 | +) |
| 41 | + |
| 42 | +wgt_qspec = QuantizationSpec( |
| 43 | + dtype=torch.int8, |
| 44 | + quant_min=-128, |
| 45 | + quant_max=127, |
| 46 | + qscheme=torch.per_tensor_symmetric, |
| 47 | + is_dynamic=False, |
| 48 | + observer_or_fake_quant_ctr=MinMaxObserver, |
| 49 | + ch_axis=0 |
| 50 | +) |
| 51 | + |
| 52 | +wgt_fc_qspec = QuantizationSpec( |
| 53 | + dtype=torch.int8, |
| 54 | + quant_min=-128, |
| 55 | + quant_max=127, |
| 56 | + qscheme=torch.per_tensor_symmetric, |
| 57 | + is_dynamic=False, |
| 58 | + observer_or_fake_quant_ctr=MinMaxObserver, |
| 59 | +) |
| 60 | +# Bias Quantization Specification is as follows: |
| 61 | +# dtype = torch.int32 |
| 62 | +# quant_min, quant_max - full int32 range |
| 63 | +# qcheme = torch.per_channel_symetric (for Conv), torch.per_tensor_symetric for Addmmn ==> i.e. zero_point = 0 |
| 64 | +# scale = input_scale * weight_scale |
| 65 | +# Is set by the *PatternQuantizer directly. |
| 66 | +bias_qspec = None |
| 67 | + |
| 68 | + |
| 69 | +class SharedSpecPattern(QuantizationPattern): |
| 70 | + """ |
| 71 | + Quantization pattern for shared quantization. |
| 72 | +
|
| 73 | + The quantization is derived from the previous node quantization and the input and output shares the same |
| 74 | + quantization parameters (scale and zero-point). |
| 75 | + """ |
| 76 | + |
| 77 | + def partition_types(self) -> List[Type[torch.nn.Module]]: |
| 78 | + pass |
| 79 | + |
| 80 | + def get_anchors( |
| 81 | + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] |
| 82 | + ) -> PartitionAnchors | None: |
| 83 | + node = fused_partition[0].nodes[-1] |
| 84 | + assert len(fused_partition[0].input_nodes) == 1 |
| 85 | + prev_node = fused_partition[0].input_nodes[0] |
| 86 | + |
| 87 | + # In the case of a node with shared quantization spec has no previous node, return None to not quantize the node |
| 88 | + if not hasattr(prev_node, "meta") or "quantization_annotation" not in prev_node.meta: |
| 89 | + return None |
| 90 | + else: |
| 91 | + qspec = SharedQuantizationSpec(prev_node) |
| 92 | + |
| 93 | + return PartitionAnchors( |
| 94 | + inputs=[(node, 0)], |
| 95 | + weights=[], |
| 96 | + biases=[], |
| 97 | + output=[(node, qspec), ], |
| 98 | + ) |
| 99 | + |
| 100 | + def replacement_op(self): |
| 101 | + assert False |
| 102 | + |
| 103 | + |
| 104 | +class MaxPoolPattern(SharedSpecPattern): |
| 105 | + """ |
| 106 | + Quantizer for MaxPool2D operator. |
| 107 | + """ |
| 108 | + |
| 109 | + def partition_types(self): |
| 110 | + return [torch.ops.aten.max_pool2d.default] |
| 111 | + |
| 112 | + |
| 113 | +class AvgPoolPattern(SharedSpecPattern): |
| 114 | + """ |
| 115 | + Quantizer for AvgPool2D operator. |
| 116 | + """ |
| 117 | + |
| 118 | + def partition_types(self): |
| 119 | + return [torch.ops.aten.avg_pool2d.default] |
| 120 | + |
| 121 | + |
| 122 | +class PadPattern(SharedSpecPattern): |
| 123 | + """ |
| 124 | + Quantizer for Pad operator. |
| 125 | + """ |
| 126 | + |
| 127 | + def partition_types(self): |
| 128 | + return [torch.ops.aten.pad.default] |
| 129 | + |
| 130 | + |
| 131 | +class ReluPattern(SharedSpecPattern): |
| 132 | + """ |
| 133 | + Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer. |
| 134 | + """ |
| 135 | + |
| 136 | + def partition_types(self): |
| 137 | + return [torch.ops.aten.relu.default] |
| 138 | + |
| 139 | + |
| 140 | +class ReluInPlacePattern(SharedSpecPattern): |
| 141 | + """ |
| 142 | + Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually |
| 143 | + follows computation layer. |
| 144 | + """ |
| 145 | + |
| 146 | + def partition_types(self): |
| 147 | + return [torch.ops.aten.relu_.default] |
| 148 | + |
| 149 | + |
| 150 | +class ReshapePattern(SharedSpecPattern): |
| 151 | + """ |
| 152 | + Quantizer for Reshape operator. |
| 153 | + """ |
| 154 | + |
| 155 | + def partition_types(self): |
| 156 | + return [torch.ops.aten.reshape.default] |
| 157 | + |
| 158 | + |
| 159 | +class PermutePattern(SharedSpecPattern): |
| 160 | + """ |
| 161 | + Quantizer for Permute operator. |
| 162 | + """ |
| 163 | + |
| 164 | + def partition_types(self): |
| 165 | + return [torch.ops.aten.permute.default] |
| 166 | + |
| 167 | + |
| 168 | +class SoftMaxPattern(QuantizationPattern): |
| 169 | + """ |
| 170 | + Quantizer for Softmax operator. |
| 171 | +
|
| 172 | + The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8. |
| 173 | + """ |
| 174 | + |
| 175 | + def partition_types(self) -> List[OpOverload]: |
| 176 | + return [torch.ops.aten.softmax.int] |
| 177 | + |
| 178 | + def get_anchors( |
| 179 | + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] |
| 180 | + ) -> PartitionAnchors: |
| 181 | + node = fused_partition[0].nodes[-1] |
| 182 | + assert len(fused_partition[0].input_nodes) == 1 |
| 183 | + |
| 184 | + qspec = FixedQParamsQuantizationSpec( |
| 185 | + dtype=torch.int8, |
| 186 | + scale=1.0 / 256.0, |
| 187 | + zero_point=-128, |
| 188 | + quant_min=-128, |
| 189 | + quant_max=127, |
| 190 | + qscheme=torch.per_tensor_affine, |
| 191 | + ) |
| 192 | + |
| 193 | + return PartitionAnchors( |
| 194 | + inputs=[(node, 0)], |
| 195 | + weights=[], |
| 196 | + biases=[], |
| 197 | + output=[(node, qspec), ], |
| 198 | + ) |
| 199 | + |
| 200 | + def replacement_op(self): |
| 201 | + assert False |
| 202 | + |
| 203 | + |
| 204 | +class NeutronQuantizer(ComposableQuantizer): |
| 205 | + def __init__(self): |
| 206 | + static_qconfig = QuantizationConfig( |
| 207 | + act_qspec, |
| 208 | + act_qspec, |
| 209 | + wgt_qspec, |
| 210 | + None, |
| 211 | + ) |
| 212 | + static_fc_qconfig = QuantizationConfig( |
| 213 | + act_qspec, |
| 214 | + act_qspec, |
| 215 | + wgt_fc_qspec, |
| 216 | + None |
| 217 | + ) |
| 218 | + super().__init__( |
| 219 | + [ |
| 220 | + CadenceAtenQuantizer(AddmmPattern(), static_fc_qconfig), |
| 221 | + CadenceAtenQuantizer(Conv1dPattern(), static_qconfig), |
| 222 | + CadenceAtenQuantizer(Conv2dPattern(), static_qconfig), |
| 223 | + CadenceAtenQuantizer(LinearPattern(), static_fc_qconfig), |
| 224 | + CadenceAtenQuantizer(MaxPoolPattern(), static_qconfig), |
| 225 | + CadenceAtenQuantizer(SoftMaxPattern(), static_qconfig), |
| 226 | + CadenceAtenQuantizer(ReshapePattern(), static_qconfig), |
| 227 | + CadenceAtenQuantizer(PermutePattern(), static_qconfig), |
| 228 | + CadenceAtenQuantizer(PadPattern(), static_qconfig), |
| 229 | + CadenceAtenQuantizer(ReluPattern(), static_qconfig), |
| 230 | + CadenceAtenQuantizer(ReluInPlacePattern(), static_qconfig), |
| 231 | + CadenceAtenQuantizer(AvgPoolPattern(), static_qconfig), |
| 232 | + ] |
| 233 | + ) |
| 234 | + |
| 235 | + def transform_for_annotation( |
| 236 | + self, model: torch.fx.GraphModule |
| 237 | + ) -> torch.fx.GraphModule: |
| 238 | + return model |
| 239 | + |
| 240 | + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 241 | + return super().annotate(model) |
| 242 | + |
| 243 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 244 | + return super().validate(model) |
0 commit comments