diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index ff893f4e45..1c5d47bda9 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -58,7 +58,7 @@ def quantize_pt2(

     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
         for q in quantizer.quantizers
     ]
     QuantFusion(patterns)(converted_model)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index d00028e422..5f0d847e84 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
@@ -373,9 +371,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         quant_node,
                         op_node,
                     )
-                elif isinstance(pattern, LinearPattern) or isinstance(
-                    pattern, LinearFunctionalPattern
-                ):
+                elif isinstance(pattern, LinearPattern):
                     args, kwargs = get_args_and_kwargs_linear(
                         graph_module,
                         inputs_inputs,
@@ -385,9 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         bias_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, LayerNormPattern) or isinstance(
-                    pattern, LayerNormFunctionalPattern
-                ):
+                elif isinstance(pattern, LayerNormPattern):
                     args, kwargs = get_args_and_kwargs_layer_norm(
                         graph_module,
                         inputs_inputs,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index e4fcbe901d..943b9e473a 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -8,7 +8,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

@@ -47,17 +47,15 @@ class PartitionAnchors:

 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(
-        self,
-    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
+    def partition_types(self) -> list[OpOverload]:
         """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
         """
         pass

     @abstractmethod
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
     ) -> Optional[PartitionAnchors]:
         pass
@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:


 class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:


 class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:


 class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:


 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:


 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.layer_norm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:


 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default


-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
-
-
 class MatmulPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:


 class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 54d7640359..5a2c101512 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -14,15 +14,14 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
@@ -31,7 +30,6 @@

 from torch import fx
 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -63,7 +61,7 @@
     bias_qspec: Optional[QuantizationSpec] = None


-class CadenceGenericQuantizer(Quantizer):
+class CadenceAtenQuantizer(Quantizer):
     def __init__(
         self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
     ) -> None:
@@ -72,7 +70,7 @@ def __init__(
         self.quantization_config = quantization_config

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        fused_partitions = find_sequential_partitions(
+        fused_partitions = find_sequential_partitions_aten(
             model,
             self.pattern.partition_types(),
         )
@@ -154,15 +152,13 @@ def __init__(self) -> None:
         )
         super().__init__(
             [
-                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
-                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
+                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
+                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
+                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
             ]
         )
diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py
index 21dac6b0b0..2afe5aba32 100644
--- a/backends/cadence/aot/quantizer/utils.py
+++ b/backends/cadence/aot/quantizer/utils.py
@@ -4,14 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
+import itertools
+from collections import OrderedDict
 from math import frexp, isclose, trunc
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple, Type

 import torch
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization import ObserverOrFakeQuantize
 from torch.fx import GraphModule
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    SourcePartition,
+)


 def quantize_tensor_multiplier(
@@ -127,3 +136,101 @@ def get_bias_qparams(

 def get_conv_args(arg, first_val: int) -> List[fx.Node]:
     return arg if len(arg) == 2 else [first_val, arg[0]]
+
+
+def get_aten_node_target_partitions(
+    graph: torch.fx.Graph,
+    wanted_original_aten_op: List[OpOverload],
+):
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_original_aten_op: List of original aten ops (OpOverload)
+
+    Returns:
+        Dictionary mapping aten ops that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        aten ops.
+ """ + modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {} + + for node in graph.nodes: + # The metadata source_fn should contain a tuple of a unique name for the + # source, and the source function if the node is decomposed from a + # function, or the type of module if the node is decomposed from a leaf + # module + # TODO(matthiascremon): look into ways to avoid using source_fn_stack + if (source_fn_st := node.meta.get("source_fn_stack")) is None: + continue + + source_fn = source_fn_st[-1] + if node.target not in wanted_original_aten_op: + continue + + diff_modules = modules.setdefault(source_fn[1], {}) + partition = diff_modules.setdefault(node.name, []) + partition.append(node) + + def make_partition( + nodes: List[torch.fx.Node], module_type: Type + ) -> SourcePartition: + input_nodes = set() + output_nodes = set() + params = set() + for node in nodes: + for arg in node.args: + if isinstance(arg, torch.fx.Node) and arg not in nodes: + input_nodes.add(arg) + + if node.op == "get_attr": + params.add(node) + + for user in node.users.keys(): + if user not in nodes: + output_nodes.add(node) + + return SourcePartition( + nodes, + module_type, + list(input_nodes), + list(output_nodes), + list(params), # type: ignore[arg-type] + ) + + ret: Dict[Type[Any], List[SourcePartition]] = {} + + for k, v in modules.items(): + ret[k] = [make_partition(partition, k) for partition in v.values()] + + return ret + + +def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool: + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def find_sequential_partitions_aten( + gm: torch.fx.GraphModule, + partition_types: List[Any], +): + typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + partitions = get_aten_node_target_partitions(gm.graph, [partition_type]) + typed_partitions[partition_type] = list( + itertools.chain.from_iterable(partitions.values()) + ) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [] + for candidate in fusion_candidates: + if _partitions_sequential(candidate): + fused_partitions.append(candidate) + return fused_partitions diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index 90ba68e538..f0c294260a 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import logging import operator from typing import Dict, List, Tuple @@ -116,7 +118,7 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]: def print_ops_info( to_edge_gm: torch.fx.GraphModule, jarvis_gm: torch.fx.GraphModule, -): +) -> None: to_edge_ops_count = get_ops_count(to_edge_gm) jarvis_ops_count = get_ops_count(jarvis_gm)