Migrate the quantizer to use aten ops directly (#4195)
Summary:
Pull Request resolved: #4195

This major change allows much more flexibility in the quantizer and reduces the dependency on decomposition and graph-tracing tools.

The motivation is that some of those tools do not preserve or propagate `source_fn_stack` information, resulting in quantization misses. SDPA is one example: its underlying `bmm` ops cannot be quantized with `source_fn_stack` information alone. MHA is another: it can hide its SDPA component, and sometimes even its `linear` ops, depending on the model (see ViT for an example).

Also note that in most cases we match single nodes anyway, with a 1-1 mapping between the op (whether `nn.Module` or `nn.functional`) and the aten op, so using the aten op directly is simply easier.
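
For illustration only (not part of this diff), here is a minimal sketch showing that the module and functional spellings of the same op lower to a single aten target, which is what the quantizer now matches on. It assumes a recent PyTorch where `torch.export.export` produces a pre-dispatch ATen graph; the exact targets printed can vary by version.

```python
import torch


class ModuleLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class FunctionalLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)


for name, mod in (("module", ModuleLinear()), ("functional", FunctionalLinear())):
    ep = torch.export.export(mod, (torch.randn(2, 8),))
    targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
    # Both variants typically show the same torch.ops.aten.linear.default
    # target here, so matching on node.target covers both spellings without
    # relying on source_fn_stack metadata.
    print(name, targets)
```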

Summary of the changes:
- change the quantizer to match aten ops directly, through `node.target` (see the sketch after this list)
- propagate required changes to the `QuantFusion` pass
- update/remove existing patterns
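
As a rough illustration of the matching approach (a simplified sketch, not the actual `find_sequential_partitions_aten` implementation in `backends/cadence/aot/quantizer/utils.py`), a single-op matcher that relies only on `node.target` could look like this:

```python
from typing import List

import torch
from torch import fx
from torch._ops import OpOverload


def match_single_aten_op(gm: fx.GraphModule, op: OpOverload) -> List[fx.Node]:
    """Return every call_function node whose target is exactly `op`.

    Matching on `node.target` depends only on the ops present in the graph,
    not on `source_fn_stack` metadata surviving tracing or decompositions.
    """
    return [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == op
    ]


# Hypothetical usage: find all aten linear nodes in an exported graph module.
# nodes = match_single_aten_op(exported_program.graph_module,
#                              torch.ops.aten.linear.default)
```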

Reviewed By: dulinriley

Differential Revision: D59552606

fbshipit-source-id: 0bc39679df9d4abfcca0f0091ec96fb94e5177b8
mcremon-meta authored and facebook-github-bot committed Jul 16, 2024
1 parent a7ac3d5 commit a22e809
Showing 6 changed files with 148 additions and 113 deletions.
4 changes: 2 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -19,7 +19,7 @@
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceGenericQuantizer,
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -64,7 +64,7 @@ def quantize_pt2(

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceGenericQuantizer).pattern
assert_is_instance(q, CadenceAtenQuantizer).pattern
for q in quantizer.quantizers
]
QuantFusion(patterns)(converted_model)
14 changes: 4 additions & 10 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -14,21 +14,19 @@
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
for pattern in self.patterns:
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
graph_module,
pattern.partition_types(),
)
@@ -373,9 +371,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
quant_node,
op_node,
)
elif isinstance(pattern, LinearPattern) or isinstance(
pattern, LinearFunctionalPattern
):
elif isinstance(pattern, LinearPattern):
args, kwargs = get_args_and_kwargs_linear(
graph_module,
inputs_inputs,
@@ -385,9 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs,
quant_node,
)
elif isinstance(pattern, LayerNormPattern) or isinstance(
pattern, LayerNormFunctionalPattern
):
elif isinstance(pattern, LayerNormPattern):
args, kwargs = get_args_and_kwargs_layer_norm(
graph_module,
inputs_inputs,
104 changes: 20 additions & 84 deletions backends/cadence/aot/quantizer/patterns.py
@@ -8,7 +8,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -47,17 +47,15 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
@abstractmethod
def partition_types(
self,
) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
def partition_types(self) -> list[OpOverload]:
"""
List of types to be passed to find_sequential_partitions.
List of types to be passed to find_sequential_partitions_aten.
"""
pass

@abstractmethod
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
pass

@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:


class AddmmPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.addmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.addmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.bmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.bmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv1d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:


class Conv2dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv2d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:


class LayerNormPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.LayerNorm]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

# Weights and biases are used as fp32 by our kernel, so they are
# passed in as others here along with the normalized shape.
return PartitionAnchors(
inputs=[(layer_norm_node, 0)],
weights=[],
biases=[],
# Ordering: normalized_shape, weights, bias
others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
output=[(layer_norm_node,)],
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default


class LayerNormFunctionalPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.layer_norm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.layer_norm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:


class LinearPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Linear]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.linear.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class LinearFunctionalPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.linear]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(linear_node.args[0], linear_node),
(linear_node.args[1], linear_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(linear_node.args) > 2 and linear_node.args[2] is not None:
bias = [(linear_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(linear_node, 0)],
weights=[(linear_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.matmul]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.matmul.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:


class ReluPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.ReLU]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.relu.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
26 changes: 11 additions & 15 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -14,15 +14,14 @@
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
QuantizationPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
find_sequential_partitions_aten,
is_annotated,
no_outside_users,
)
@@ -31,7 +30,6 @@
from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -63,7 +61,7 @@
bias_qspec: Optional[QuantizationSpec] = None


class CadenceGenericQuantizer(Quantizer):
class CadenceAtenQuantizer(Quantizer):
def __init__(
self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
) -> None:
@@ -72,7 +70,7 @@ def __init__(
self.quantization_config = quantization_config

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
model,
self.pattern.partition_types(),
)
@@ -154,15 +152,13 @@ def __init__(self) -> None:
)
super().__init__(
[
CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
CadenceGenericQuantizer(BmmPattern(), static_qconfig),
CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(LinearPattern(), static_qconfig),
CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
CadenceGenericQuantizer(ReluPattern(), static_qconfig),
CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
CadenceAtenQuantizer(BmmPattern(), static_qconfig),
CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
CadenceAtenQuantizer(LinearPattern(), static_qconfig),
CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
CadenceAtenQuantizer(ReluPattern(), static_qconfig),
]
)
