diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index ff893f4e45..1c5d47bda9 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -58,7 +58,7 @@ def quantize_pt2(

     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
         for q in quantizer.quantizers
     ]
     QuantFusion(patterns)(converted_model)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 803379b3bd..af1a1c720a 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
@@ -375,9 +373,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         quant_node,
                         op_node,
                     )
-                elif isinstance(pattern, LinearPattern) or isinstance(
-                    pattern, LinearFunctionalPattern
-                ):
+                elif isinstance(pattern, LinearPattern):
                     args, kwargs = get_args_and_kwargs_linear(
                         graph_module,
                         inputs_inputs,
@@ -387,9 +383,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         bias_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, LayerNormPattern) or isinstance(
-                    pattern, LayerNormFunctionalPattern
-                ):
+                elif isinstance(pattern, LayerNormPattern):
                     args, kwargs = get_args_and_kwargs_layer_norm(
                         graph_module,
                         inputs_inputs,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 381bbbf6f1..943b9e473a 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -4,14 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+# pyre-strict
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
@@ -44,18 +47,20 @@ class PartitionAnchors:

 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(self):
+    def partition_types(self) -> list[OpOverload]:
         """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
         """
         pass

     @abstractmethod
-    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
         pass

     @abstractmethod
-    def replacement_op(self) -> Callable[..., Any]:
+    def replacement_op(self) -> OpOverload:
         """
         Operator (most likely a custom one) that this partition should be fused into
         in the backend. Refer to the QuantFusion pass for examples.
@@ -64,8 +69,8 @@ def replacement_op(self) -> Callable[..., Any]:


 class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -91,13 +96,13 @@ def get_anchors(
             output=[(addmm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear


 class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -111,13 +116,13 @@ def get_anchors(
             output=[(bmm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default


 class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -149,13 +154,13 @@ def get_anchors(
             output=[(conv1d_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default


 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -187,37 +192,17 @@ def get_anchors(
             output=[(conv2d_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default


 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]

-    def replacement_op(self):
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.layer_norm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]

         others = [(layer_norm_node, 1)]
@@ -241,13 +226,13 @@ def get_anchors(gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default


 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -279,51 +264,13 @@ def get_anchors(
             output=[(linear_node,)],
         )

-    def replacement_op(self):
-        return torch.ops.cadence.quantized_linear.default
-
-
-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default


 class MatmulPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -337,13 +284,13 @@ def get_anchors(
             output=[(matmul_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default


 class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -359,5 +306,5 @@ def get_anchors(
             ],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
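With the patterns now keyed on the ATen OpOverload found in node.target, a single pattern covers both the module and the functional spelling of an op, which is why the *FunctionalPattern classes are deleted above. For illustration only, a new pattern written against the updated interface could look roughly like the sketch below; the TanhPattern name and the torch.ops.cadence.quantized_tanh target are hypothetical and assume such a Cadence kernel is registered, while the anchor layout mirrors the single-input patterns in this file.

from typing import List, Optional

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    PartitionAnchors,
    QuantizationPattern,
)
from torch import fx
from torch._ops import OpOverload


class TanhPattern(QuantizationPattern):
    # Hypothetical single-op pattern sketch, not part of this patch.
    def partition_types(self) -> List[OpOverload]:
        # Match the ATen overload produced by both nn.Tanh and torch.tanh.
        return [torch.ops.aten.tanh.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        tanh_node = fused_partition[0].nodes[-1]
        # Quantize the single activation input and the output; no weights/biases.
        return PartitionAnchors(
            inputs=[(tanh_node, 0)],
            weights=[],
            biases=[],
            output=[(tanh_node,)],
        )

    def replacement_op(self) -> OpOverload:
        # Assumed Cadence op; it must exist in the cadence op library to be used.
        return torch.ops.cadence.quantized_tanh.default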
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index f527562962..130e9436a6 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List
+from typing import List, Optional

 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
@@ -12,14 +12,14 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
+    QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
@@ -27,7 +27,6 @@

 from torch import fx
 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -56,17 +55,19 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )

-bias_qspec = None
+bias_qspec: Optional[QuantizationSpec] = None


-class CadenceGenericQuantizer(Quantizer):
-    def __init__(self, pattern, quantization_config):
+class CadenceAtenQuantizer(Quantizer):
+    def __init__(
+        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
+    ):
         super().__init__()
         self.pattern = pattern
         self.quantization_config = quantization_config

     def annotate(self, model):
-        fused_partitions = find_sequential_partitions(
+        fused_partitions = find_sequential_partitions_aten(
             model,
             self.pattern.partition_types(),
         )
@@ -133,15 +134,13 @@ def __init__(self):
         )
         super().__init__(
             [
-                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
-                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
+                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
+                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
+                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
             ]
         )
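End to end, the renamed CadenceAtenQuantizer plugs into the standard PT2E flow exactly as the old class did; the difference is that annotation now runs on the exported ATen graph. The sketch below paraphrases quantize_pt2 from compiler.py at the top of this patch; the export step via torch.export.export and the wrapper function name are assumptions for illustration, not code from this patch, which performs its own capture.

import torch
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_and_fuse(model: torch.nn.Module, inputs: tuple) -> torch.fx.GraphModule:
    # Hypothetical driver mirroring quantize_pt2 in compiler.py.
    quantizer = CadenceQuantizer()

    # Assumption: capture the model to an ATen-level GraphModule first;
    # prepare_pt2e expects a GraphModule, not an eager nn.Module.
    graph_module = torch.export.export(model, inputs).module()

    # Annotate, calibrate, and convert with the composed aten quantizers.
    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*inputs)
    converted = convert_pt2e(prepared)

    # Fuse dq -> op -> q chains into Cadence quantized ops, as compiler.py does.
    patterns = [
        q.pattern for q in quantizer.quantizers if isinstance(q, CadenceAtenQuantizer)
    ]
    QuantFusion(patterns)(converted)
    return converted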
diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py
index 21dac6b0b0..0fe320c16f 100644
--- a/backends/cadence/aot/quantizer/utils.py
+++ b/backends/cadence/aot/quantizer/utils.py
@@ -4,14 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
+from collections import OrderedDict
 from math import frexp, isclose, trunc
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple, Type

 import torch
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization import ObserverOrFakeQuantize
 from torch.fx import GraphModule
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    SourcePartition,
+)


 def quantize_tensor_multiplier(
@@ -127,3 +134,101 @@ def get_bias_qparams(

 def get_conv_args(arg, first_val: int) -> List[fx.Node]:
     return arg if len(arg) == 2 else [first_val, arg[0]]
+
+
+def get_aten_node_target_partitions(
+    graph: torch.fx.Graph,
+    wanted_original_aten_op: List[OpOverload],
+) -> Dict[Any, List[SourcePartition]]:
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_original_aten_op: List of original aten ops (OpOverload)
+
+    Returns:
+        Dictionary mapping aten ops that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        aten ops.
+    """
+    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}
+
+    for node in graph.nodes:
+        # The metadata source_fn should contain a tuple of a unique name for the
+        # source, and the source function if the node is decomposed from a
+        # function, or the type of module if the node is decomposed from a leaf
+        # module
+
+        if (source_fn_st := node.meta.get("source_fn_stack", None)) is None:
+            continue
+
+        source_fn = source_fn_st[-1]
+        if node.target not in wanted_original_aten_op:
+            continue
+
+        diff_modules = modules.setdefault(source_fn[1], {})
+        partition = diff_modules.setdefault(node.name, [])
+        partition.append(node)
+
+    def make_partition(
+        nodes: List[torch.fx.Node], module_type: Type
+    ) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, torch.fx.Node) and arg not in nodes:
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node)
+
+            for user in node.users.keys():
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: Dict[Type[Any], List[SourcePartition]] = {}
+
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
+
+def _partitions_sequential(partitions: List[SourcePartition]):
+    prev_partition = None
+    for partition in partitions:
+        if prev_partition is not None and not check_subgraphs_connected(
+            prev_partition, partition
+        ):
+            return False
+        prev_partition = partition
+    return True
+
+
+def find_sequential_partitions_aten(
+    gm: torch.fx.GraphModule,
+    partition_types: List[Any],
+):
+    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
+    for partition_type in partition_types:
+        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
+        typed_partitions[partition_type] = list(
+            itertools.chain.from_iterable(partitions.values())
+        )
+
+    typed_partitions_list = list(typed_partitions.values())
+    fusion_candidates = itertools.product(*typed_partitions_list)
+    fused_partitions = []
+    for candidate in fusion_candidates:
+        if _partitions_sequential(candidate):  # type: ignore[arg-type]
+            fused_partitions.append(candidate)
+    return fused_partitions
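The helpers added to utils.py can also be exercised on their own, which is handy when debugging why a pattern does or does not fire. A minimal sketch, assuming graph_module is an ATen-level fx.GraphModule whose nodes still carry the source_fn_stack metadata that get_aten_node_target_partitions keys on; the helper function name is hypothetical.

import torch
from executorch.backends.cadence.aot.quantizer.patterns import LinearPattern
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
)


def print_linear_partitions(graph_module: torch.fx.GraphModule) -> None:
    # List every aten.linear partition that CadenceAtenQuantizer.annotate and
    # the QuantFusion pass would consider for this graph.
    fused_partitions = find_sequential_partitions_aten(
        graph_module,
        LinearPattern().partition_types(),  # [torch.ops.aten.linear.default]
    )
    for candidate in fused_partitions:
        # Each candidate is a tuple holding one SourcePartition per requested
        # partition type, connected in sequence.
        for partition in candidate:
            print([node.name for node in partition.nodes])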