Add support for quantized bmm (#4047)
Summary:
Pull Request resolved: #4047

The current quantizer only captures "fake" bmm from matmuls with specific shapes. Add support for `torch.bmm` as well.

Use a decomposition for SDPA to make sure LLaMa bmms get quantized.

Differential Revision: D58959269

Reviewed By: zonglinpengmeta, hsharma35
Matthias Cremon authored and facebook-github-bot committed Jul 10, 2024
1 parent 561c035 commit 238850b
Showing 5 changed files with 36 additions and 4 deletions.
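
For context, here is a minimal sketch (not part of the commit) of the pt2e flow this change feeds into: a module calling torch.bmm explicitly, which the new BmmPattern below can match. The capture and calibration details are assumptions based on the imports used in backends/cadence/aot/compiler.py in this diff.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class BmmModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # An explicit batched matmul; previously only "fake" bmms lowered from
        # matmuls with specific shapes were captured by the quantizer.
        return torch.bmm(x, y)


example_inputs = (torch.randn(2, 4, 8), torch.randn(2, 8, 16))
model = capture_pre_autograd_graph(BmmModule(), example_inputs)

quantizer = CadenceQuantizer()
prepared = prepare_pt2e(model, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
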
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -28,6 +28,7 @@ python_library(
"compiler.py",
],
deps = [
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
":passes",
":utils",
"//caffe2:torch",
12 changes: 9 additions & 3 deletions backends/cadence/aot/compiler.py
@@ -18,9 +18,13 @@
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceGenericQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -53,8 +57,10 @@ def quantize_pt2(
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
# pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
patterns = [q.pattern for q in quantizer.quantizers]
patterns = [
assert_is_instance(q, CadenceGenericQuantizer).pattern
for q in quantizer.quantizers
]
QuantFusion(patterns)(converted_model)

return converted_model
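
The pyre-fixme comment above is replaced by a runtime-checked cast. A minimal sketch of the behavior assumed for pyre_extensions.assert_is_instance, using hypothetical stand-in classes: it returns its argument narrowed to the given type, and raises if the instance check fails.

from pyre_extensions import assert_is_instance


class Pattern:  # hypothetical stand-in, not the real quantizer classes
    name = "bmm"


class GenericQuantizer:  # hypothetical stand-in
    def __init__(self) -> None:
        self.pattern = Pattern()


q: object = GenericQuantizer()
# Narrows q from `object` to GenericQuantizer so `.pattern` type-checks;
# assumed to raise at runtime if q is not actually a GenericQuantizer.
pattern = assert_is_instance(q, GenericQuantizer).pattern
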
5 changes: 4 additions & 1 deletion backends/cadence/aot/quantizer/fusion_pass.py
@@ -11,6 +11,7 @@
import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
@@ -396,7 +397,9 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
other_inputs,
quant_node,
)
elif isinstance(pattern, MatmulPattern):
elif isinstance(pattern, BmmPattern) or isinstance(
pattern, MatmulPattern
):
args, kwargs = get_args_and_kwargs_matmul(
inputs_inputs,
dequants_inputs,
20 changes: 20 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -95,6 +95,26 @@ def replacement_op(self):
return torch.ops.cadence.quantized_linear


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.bmm]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
bmm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
inputs=[(bmm_node, 0), (bmm_node, 1)],
weights=[],
biases=[],
output=[(bmm_node,)],
)

def replacement_op(self):
return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv1d]
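
A small illustration (not from the commit) of the anchor choice in BmmPattern above: in a traced graph, an explicit torch.bmm call becomes a single node whose two positional args are both activations, so there are no weight or bias operands to anchor.

import torch
from torch.fx import symbolic_trace


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, y)


gm = symbolic_trace(M())
bmm_nodes = [n for n in gm.graph.nodes if n.target is torch.bmm]
assert len(bmm_nodes) == 1
# Both operands are plain inputs (args 0 and 1), matching the PartitionAnchors above.
assert len(bmm_nodes[0].args) == 2
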
2 changes: 2 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -9,6 +9,7 @@
import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
@@ -133,6 +134,7 @@ def __init__(self):
super().__init__(
[
CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
CadenceGenericQuantizer(BmmPattern(), static_qconfig),
CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
