Migrate the quantizer to use aten ops directly (#4195)
Summary:
Pull Request resolved: #4195

This major change allows much more flexibility in the quantizer and reduces its dependency on decomposition and graph-tracing tools.

The motivation is that some of those tools do not preserve or propagate `source_fn_stack` information, resulting in quantization misses. SDPA is one example: its underlying `bmm` ops cannot be quantized with `source_fn_stack` information alone. MHA is another, since it can hide its SDPA component, and sometimes even its `linear` ops, depending on the model (see ViT for an example).
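
As an illustration of the issue (not code from this PR), one can export a small SDPA module, run decompositions, and print each node's target together with its `source_fn_stack`. Whether SDPA actually lowers to `bmm`/`matmul` nodes here depends on the PyTorch version and decomposition table, so the snippet below is only a hedged sketch using standard `torch.export` APIs; the `SDPAModule` name is made up for the example.

```python
# Hedged sketch: inspect whether `source_fn_stack` survives decomposition.
# Standard torch.export APIs only; not code from this PR.
import torch


class SDPAModule(torch.nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


if __name__ == "__main__":
    example_inputs = (torch.randn(1, 4, 8, 16),) * 3
    ep = torch.export.export(SDPAModule(), example_inputs)
    # Depending on the version/decomposition table, SDPA may be lowered into
    # matmul/bmm + softmax nodes at this point.
    gm = ep.run_decompositions().module()
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # For decomposed nodes, "source_fn_stack" is often missing or points
            # at the original SDPA call rather than a quantizable source, which
            # is the kind of miss described above.
            print(node.target, node.meta.get("source_fn_stack"))
```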

Summary of the changes:
- change the quantizer to match aten ops directly through `node.target` (see the sketch below)
- propagate required changes to the `QuantFusion` pass
- update/remove existing patterns
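
As a rough sketch of the new matching style (hypothetical names, not the actual Cadence classes), a pattern can return aten overloads from `partition_types()` and be matched purely on `node.target`. The `LinearAtenPatternSketch` class and `match_partition_types` helper below are illustrative stand-ins, loosely modeled on the migrated `LinearPattern` in this diff.

```python
# Hedged sketch (hypothetical names): matching aten overloads on node.target.
from typing import List

import torch
from torch import fx
from torch._ops import OpOverload


class LinearAtenPatternSketch:
    """Simplified stand-in for a migrated quantizer pattern (illustration only)."""

    def partition_types(self) -> List[OpOverload]:
        # The migrated patterns return aten overloads instead of nn.Module or
        # functional types, e.g. torch.ops.aten.linear.default for LinearPattern.
        return [torch.ops.aten.linear.default]


def match_partition_types(gm: fx.GraphModule, targets: List[OpOverload]) -> List[fx.Node]:
    # After export, call_function targets are aten OpOverloads, so matching on
    # node.target does not depend on source_fn_stack metadata surviving.
    return [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target in targets
    ]


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


if __name__ == "__main__":
    gm = torch.export.export(TinyLinear(), (torch.randn(2, 8),)).module()
    # Depending on the export mode/PyTorch version, linear may already appear
    # decomposed (e.g. as addmm), so include that overload as a fallback here.
    targets = LinearAtenPatternSketch().partition_types() + [torch.ops.aten.addmm.default]
    print(match_partition_types(gm, targets))
```

These aten overloads are what `partition_types()` now feeds into `find_sequential_partitions_aten` in the updated annotation and fusion passes.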

Differential Revision: D59552606
mcremon-meta authored and facebook-github-bot committed Jul 10, 2024
1 parent 238850b commit 4542695
Showing 5 changed files with 166 additions and 121 deletions.
4 changes: 2 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -19,7 +19,7 @@
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceGenericQuantizer,
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -58,7 +58,7 @@ def quantize_pt2(

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceGenericQuantizer).pattern
assert_is_instance(q, CadenceAtenQuantizer).pattern
for q in quantizer.quantizers
]
QuantFusion(patterns)(converted_model)
14 changes: 4 additions & 10 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -14,21 +14,19 @@
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
for pattern in self.patterns:
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
graph_module,
pattern.partition_types(),
)
Expand Down Expand Up @@ -375,9 +373,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
quant_node,
op_node,
)
elif isinstance(pattern, LinearPattern) or isinstance(
pattern, LinearFunctionalPattern
):
elif isinstance(pattern, LinearPattern):
args, kwargs = get_args_and_kwargs_linear(
graph_module,
inputs_inputs,
@@ -387,9 +383,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs,
quant_node,
)
elif isinstance(pattern, LayerNormPattern) or isinstance(
pattern, LayerNormFunctionalPattern
):
elif isinstance(pattern, LayerNormPattern):
args, kwargs = get_args_and_kwargs_layer_norm(
graph_module,
inputs_inputs,
127 changes: 37 additions & 90 deletions backends/cadence/aot/quantizer/patterns.py
@@ -4,14 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
SharedQuantizationSpec,
@@ -44,18 +47,20 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
@abstractmethod
def partition_types(self):
def partition_types(self) -> list[OpOverload]:
"""
List of types to be passed to find_sequential_partitions.
List of types to be passed to find_sequential_partitions_aten.
"""
pass

@abstractmethod
def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
def get_anchors(
self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
pass

@abstractmethod
def replacement_op(self) -> Callable[..., Any]:
def replacement_op(self) -> OpOverload:
"""
Operator (most likely a custom one) that this partition should be fused into in
the backend. Refer to the QuantFusion pass for examples.
Expand All @@ -64,8 +69,8 @@ def replacement_op(self) -> Callable[..., Any]:


class AddmmPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.addmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.addmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
Expand All @@ -91,13 +96,13 @@ def get_anchors(
output=[(addmm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.bmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.bmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -111,13 +116,13 @@ def get_anchors(
output=[(bmm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv1d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -149,13 +154,13 @@ def get_anchors(
output=[(conv1d_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv.default


class Conv2dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv2d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -187,37 +192,17 @@ def get_anchors(
output=[(conv2d_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv.default


class LayerNormPattern(QuantizationPattern):
def partition_types(self):
return [torch.nn.LayerNorm]

def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

# Weights and biases are used as fp32 by our kernel, so they are
# passed in as others here along with the normalized shape.
return PartitionAnchors(
inputs=[(layer_norm_node, 0)],
weights=[],
biases=[],
# Ordering: normalized_shape, weights, bias
others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
output=[(layer_norm_node,)],
)
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.layer_norm.default]

def replacement_op(self):
return torch.ops.cadence.quantized_layer_norm.default


class LayerNormFunctionalPattern(QuantizationPattern):
def partition_types(self):
return [torch.nn.functional.layer_norm]

def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

others = [(layer_norm_node, 1)]
@@ -241,13 +226,13 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
output=[(layer_norm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default


class LinearPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Linear]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.linear.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -279,51 +264,13 @@ def get_anchors(
output=[(linear_node,)],
)

def replacement_op(self):
return torch.ops.cadence.quantized_linear.default


class LinearFunctionalPattern(QuantizationPattern):
def partition_types(self):
return [torch.nn.functional.linear]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(linear_node.args[0], linear_node),
(linear_node.args[1], linear_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(linear_node.args) > 2 and linear_node.args[2] is not None:
bias = [(linear_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(linear_node, 0)],
weights=[(linear_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
def partition_types(self):
return [torch.matmul]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.matmul.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -337,13 +284,13 @@ def get_anchors(
output=[(matmul_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class ReluPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.ReLU]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.relu.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -359,5 +306,5 @@ def get_anchors(
],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_relu.default
35 changes: 17 additions & 18 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -4,30 +4,29 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
from typing import List, Optional

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
QuantizationPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
find_sequential_partitions_aten,
is_annotated,
no_outside_users,
)

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -56,17 +55,19 @@
observer_or_fake_quant_ctr=MinMaxObserver,
)

bias_qspec = None
bias_qspec: Optional[QuantizationSpec] = None


class CadenceGenericQuantizer(Quantizer):
def __init__(self, pattern, quantization_config):
class CadenceAtenQuantizer(Quantizer):
def __init__(
self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
):
super().__init__()
self.pattern = pattern
self.quantization_config = quantization_config

def annotate(self, model):
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
model,
self.pattern.partition_types(),
)
@@ -133,15 +134,13 @@ def __init__(self):
)
super().__init__(
[
CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
CadenceGenericQuantizer(BmmPattern(), static_qconfig),
CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(LinearPattern(), static_qconfig),
CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
CadenceGenericQuantizer(ReluPattern(), static_qconfig),
CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
CadenceAtenQuantizer(BmmPattern(), static_qconfig),
CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
CadenceAtenQuantizer(LinearPattern(), static_qconfig),
CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
CadenceAtenQuantizer(ReluPattern(), static_qconfig),
]
)