Commit 6d1694d

mcremon-meta authored and facebook-github-bot committed
Migrate the quantizer to use aten ops directly (#4195)
Summary:
Pull Request resolved: #4195

This major change allows a lot more flexibility in the quantizer and reduces the dependency on the decomposition/graph-tracing tools. The motivation is that some of those tools do not preserve or propagate `source_fn_stack` information, resulting in quantization misses. SDPA is one example, where the underlying `bmm` ops cannot be quantized with `source_fn_stack` information alone; MHA is another, since it can hide its SDPA component and sometimes even `linear` ops, depending on the model (see ViT for an example). Also note that in most cases we match single nodes anyway, with a 1-1 mapping between the op (either nn.Module or nn.functional) and the aten op, so using the aten op directly is simply easier.

Summary of the changes:
- change the quantizer to match aten ops directly, through `node.target`
- propagate the required changes to the `QuantFusion` pass
- update/remove existing patterns

Reviewed By: dulinriley

Differential Revision: D59552606
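In practice the change boils down to matching on `node.target` (an aten `OpOverload` after export) rather than on `source_fn_stack` metadata. A minimal sketch of the idea (the helpers below are illustrative, not code from this diff):

```python
import torch
from torch import fx


# Illustrative helper, not part of this diff: match a node by its aten op
# target, which survives decompositions that drop source_fn_stack metadata.
def matches_aten_op(node: fx.Node, targets: list) -> bool:
    return node.op == "call_function" and node.target in targets


def find_bmm_nodes(gm: fx.GraphModule) -> list:
    # The bmm ops produced by a decomposed SDPA show up here even though
    # their source_fn_stack may no longer point back to the original module.
    return [
        n
        for n in gm.graph.nodes
        if matches_aten_op(n, [torch.ops.aten.bmm.default])
    ]
```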
1 parent fbe0af1 · commit 6d1694d

6 files changed: +148 additions, −113 deletions


backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -64,7 +64,7 @@ def quantize_pt2(

     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
        for q in quantizer.quantizers
     ]
     QuantFusion(patterns)(converted_model)
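For context, the dq -> op -> q fusion mentioned in the comment above collapses a dequantize/op/quantize triple into a single quantized op. A rough numeric illustration of what that triple computes (made-up scale and zero point; this is not code from the pass):

```python
import torch

# Made-up quantization parameters, for illustration only.
scale, zero_point = 0.05, 0

x_q = torch.randint(-128, 128, (1, 16), dtype=torch.int8)

# Before fusion, the graph holds three nodes: dq -> op -> q.
x = (x_q.to(torch.float32) - zero_point) * scale  # dequantize
y = torch.relu(x)                                 # op
y_q = torch.clamp(torch.round(y / scale) + zero_point, -128, 127).to(torch.int8)  # quantize

# After QuantFusion, the triple becomes one quantized op (e.g. a
# cadence quantized_relu kernel) operating on int8 directly.
```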

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 4 additions & 10 deletions

@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
@@ -373,9 +371,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         quant_node,
                         op_node,
                     )
-                elif isinstance(pattern, LinearPattern) or isinstance(
-                    pattern, LinearFunctionalPattern
-                ):
+                elif isinstance(pattern, LinearPattern):
                     args, kwargs = get_args_and_kwargs_linear(
                         graph_module,
                         inputs_inputs,
@@ -385,9 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         bias_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, LayerNormPattern) or isinstance(
-                    pattern, LayerNormFunctionalPattern
-                ):
+                elif isinstance(pattern, LayerNormPattern):
                     args, kwargs = get_args_and_kwargs_layer_norm(
                         graph_module,
                         inputs_inputs,
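`find_sequential_partitions_aten` lives in `executorch/backends/cadence/aot/quantizer/utils.py`, one of the changed files whose hunks are not shown above. A hedged sketch of what an aten-target partition finder can look like for the common single-op case (an illustration of the idea, not the actual implementation):

```python
from typing import List

import torch
from torch import fx
from torch._ops import OpOverload


def find_single_op_partitions(
    gm: fx.GraphModule, targets: List[OpOverload]
) -> List[List[fx.Node]]:
    # Match directly on node.target, the aten OpOverload present after
    # export; no source_fn_stack metadata is consulted.
    return [
        [node]
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target in targets
    ]
```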

backends/cadence/aot/quantizer/patterns.py

Lines changed: 20 additions & 84 deletions

@@ -8,7 +8,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -47,17 +47,15 @@ class PartitionAnchors:

 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(
-        self,
-    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
+    def partition_types(self) -> list[OpOverload]:
         """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
         """
         pass

     @abstractmethod
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
     ) -> Optional[PartitionAnchors]:
         pass

@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:


 class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:


 class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:


 class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:


 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:


 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.layer_norm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:


 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default


-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
-
-
 class MatmulPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:


 class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
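With the aten-op scheme, a new single-node pattern reduces to an `OpOverload`, its anchors, and a replacement op. As a hedged sketch, a hypothetical new pattern might look like this (`quantized_tanh` is a made-up replacement op; `QuantizationPattern` and `PartitionAnchors` are the classes from this file):

```python
from typing import List, Optional

import torch
from torch import fx
from torch._ops import OpOverload


# Hypothetical pattern for illustration; not part of this diff.
class TanhPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.tanh.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        tanh_node = fused_partition[0].nodes[-1]
        return PartitionAnchors(
            inputs=[(tanh_node, 0)],
            weights=[],
            biases=[],
            output=[(tanh_node,)],
        )

    def replacement_op(self) -> OpOverload:
        # Made-up op name; a real pattern would return an existing
        # torch.ops.cadence overload.
        return torch.ops.cadence.quantized_tanh.default
```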

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 11 additions & 15 deletions

@@ -14,15 +14,14 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
@@ -31,7 +30,6 @@
 from torch import fx

 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -63,7 +61,7 @@
     bias_qspec: Optional[QuantizationSpec] = None


-class CadenceGenericQuantizer(Quantizer):
+class CadenceAtenQuantizer(Quantizer):
     def __init__(
         self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
     ) -> None:
@@ -72,7 +70,7 @@ def __init__(
         self.quantization_config = quantization_config

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        fused_partitions = find_sequential_partitions(
+        fused_partitions = find_sequential_partitions_aten(
             model,
             self.pattern.partition_types(),
         )
@@ -154,15 +152,13 @@ def __init__(self) -> None:
         )
         super().__init__(
             [
-                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
-                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
+                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
+                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
+                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
             ]
         )
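Downstream usage is unchanged: the composed `CadenceQuantizer` still plugs into the standard PT2E flow. A hedged sketch, assuming a toy model and the capture API of this era (`capture_pre_autograd_graph` has since moved around across PyTorch releases):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

# Toy model and inputs; assumptions for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
inputs = (torch.randn(1, 16),)

quantizer = CadenceQuantizer()
traced = capture_pre_autograd_graph(model, inputs)
prepared = prepare_pt2e(traced, quantizer)
prepared(*inputs)  # calibrate with representative inputs
converted = convert_pt2e(prepared)
```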
