Add support for quantized bmm (#4047)
Summary:
Pull Request resolved: #4047

The current quantizer only captures "fake" bmm from matmuls with specific shapes. Add support for `torch.bmm` as well.
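
For context, a minimal illustrative sketch (not part of this diff; the module names are hypothetical): torch.matmul on two 3-D inputs already behaves as a batched matmul, which is the "fake" bmm the quantizer previously matched, while an explicit torch.bmm call computes the same thing and is what this change adds a pattern for.

import torch

class MatmulAsBmm(torch.nn.Module):
    # "Fake" bmm: torch.matmul on 3-D inputs performs a batched matmul and
    # was already matched by the existing MatmulPattern.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)

class ExplicitBmm(torch.nn.Module):
    # Explicit torch.bmm: the case this commit adds a BmmPattern for, fused
    # into the same quantized matmul op in the backend.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, y)

x = torch.randn(8, 16, 32)
y = torch.randn(8, 32, 64)
assert torch.allclose(MatmulAsBmm()(x, y), ExplicitBmm()(x, y))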

Reviewed By: dulinriley, zonglinpengmeta, hsharma35

Differential Revision: D58959269

fbshipit-source-id: cb36eede25047a144bb0334da1847ca4381f7929
mcremon-meta authored and facebook-github-bot committed Jul 12, 2024
1 parent e9aa542 commit cfbe63d
Showing 6 changed files with 102 additions and 39 deletions.
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -28,6 +28,7 @@ python_library(
"compiler.py",
],
deps = [
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
":passes",
":utils",
"//caffe2:torch",
12 changes: 9 additions & 3 deletions backends/cadence/aot/compiler.py
@@ -18,9 +18,13 @@
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceGenericQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -53,8 +57,10 @@ def quantize_pt2(
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
# pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
patterns = [q.pattern for q in quantizer.quantizers]
patterns = [
assert_is_instance(q, CadenceGenericQuantizer).pattern
for q in quantizer.quantizers
]
QuantFusion(patterns)(converted_model)

return converted_model
3 changes: 3 additions & 0 deletions backends/cadence/aot/quantizer/TARGETS
@@ -17,6 +17,7 @@ python_library(
srcs = [
"patterns.py",
],
typing = True,
deps = [
":utils",
"//caffe2:torch",
@@ -28,7 +29,9 @@ python_library(
srcs = [
"quantizer.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
":patterns",
":utils",
"//caffe2:torch",
7 changes: 3 additions & 4 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -11,6 +11,7 @@
import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
@@ -361,9 +362,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
inputs_inputs + weights_inputs + other_inputs + bias_inputs
)
kwargs = {}
if isinstance(pattern, Conv1dPattern) or isinstance(
pattern, Conv2dPattern
):
if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
args, kwargs = get_args_and_kwargs_conv(
graph_module,
inputs_inputs,
@@ -396,7 +395,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
other_inputs,
quant_node,
)
elif isinstance(pattern, MatmulPattern):
elif isinstance(pattern, (BmmPattern, MatmulPattern)):
args, kwargs = get_args_and_kwargs_matmul(
inputs_inputs,
dequants_inputs,
69 changes: 50 additions & 19 deletions backends/cadence/aot/quantizer/patterns.py
@@ -4,14 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple, Type, Union
from typing import Callable, List, Optional, Tuple, Type, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
SharedQuantizationSpec,
@@ -44,18 +47,22 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
@abstractmethod
def partition_types(self):
def partition_types(
self,
) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
"""
List of types to be passed to find_sequential_partitions.
"""
pass

@abstractmethod
def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
pass

@abstractmethod
def replacement_op(self) -> Callable[..., Any]:
def replacement_op(self) -> OpOverload:
"""
Operator (most likely a custom one) that this partition should be fused into in
the backend. Refer to the QuantFusion pass for examples.
@@ -91,10 +98,30 @@ def get_anchors(
output=[(addmm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.bmm]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
bmm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
inputs=[(bmm_node, 0), (bmm_node, 1)],
weights=[],
biases=[],
output=[(bmm_node,)],
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv1d]
@@ -129,7 +156,7 @@ def get_anchors(
output=[(conv1d_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv.default


@@ -167,15 +194,17 @@ def get_anchors(
output=[(conv2d_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv.default


class LayerNormPattern(QuantizationPattern):
def partition_types(self):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.LayerNorm]

def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

# Weights and biases are used as fp32 by our kernel, so they are
@@ -189,15 +218,17 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
output=[(layer_norm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default


class LayerNormFunctionalPattern(QuantizationPattern):
def partition_types(self):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.layer_norm]

def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

others = [(layer_norm_node, 1)]
@@ -221,7 +252,7 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
output=[(layer_norm_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default


@@ -259,12 +290,12 @@ def get_anchors(
output=[(linear_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class LinearFunctionalPattern(QuantizationPattern):
def partition_types(self):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.linear]

def get_anchors(
@@ -297,12 +328,12 @@ def get_anchors(
output=[(linear_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
def partition_types(self):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.matmul]

def get_anchors(
@@ -317,7 +348,7 @@ def get_anchors(
output=[(matmul_node,)],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


@@ -339,5 +370,5 @@ def get_anchors(
],
)

def replacement_op(self):
def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_relu.default
49 changes: 36 additions & 13 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -4,30 +4,35 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
# pyre-strict

from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
QuantizationPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
is_annotated,
no_outside_users,
)
from pyre_extensions import assert_is_instance

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OperatorConfig,
@@ -55,16 +60,18 @@
observer_or_fake_quant_ctr=MinMaxObserver,
)

bias_qspec = None
bias_qspec: Optional[QuantizationSpec] = None


class CadenceGenericQuantizer(Quantizer):
def __init__(self, pattern, quantization_config):
def __init__(
self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
) -> None:
super().__init__()
self.pattern = pattern
self.quantization_config = quantization_config

def annotate(self, model):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
fused_partitions = find_sequential_partitions(
model,
self.pattern.partition_types(),
@@ -94,25 +101,40 @@ def annotate(self, model):
continue

for output, *custom_spec in anchors.output:
output.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=custom_spec[0] if custom_spec else output_act_qspec,
_annotated=True,
assert_is_instance(output, fx.Node).meta["quantization_annotation"] = (
QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(
custom_spec[0] if custom_spec else output_act_qspec
),
_annotated=True,
)
)

def annotate_inputs(inputs, spec):
def annotate_inputs(
inputs: Union[
List[Tuple[fx.Node, int]],
List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
],
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in inputs:
annotation = node.meta.get(
_node = assert_is_instance(node, fx.Node)
annotation = _node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
annotation.input_qspec_map[node.args[idx]] = (
# pyre-ignore[6]: incompatible parameter type
annotation.input_qspec_map[_node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
node.meta["quantization_annotation"] = annotation
_node.meta["quantization_annotation"] = annotation

annotate_inputs(anchors.inputs, input_act_qspec)
annotate_inputs(anchors.weights, weight_qspec)
# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.biases, bias_qspec)
return model

def validate(self, model: fx.GraphModule) -> None:
pass
@@ -123,7 +145,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:


class CadenceQuantizer(ComposableQuantizer):
def __init__(self):
def __init__(self) -> None:
static_qconfig = QuantizationConfig(
act_qspec,
act_qspec,
@@ -133,6 +155,7 @@ def __init__(self):
super().__init__(
[
CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
CadenceGenericQuantizer(BmmPattern(), static_qconfig),
CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
