VAD perf improvements - Improve aten::full performance (#3974)
Summary: Pull Request resolved: #3974

Reviewed By: dulinriley, zonglinpengmeta

Differential Revision: D56524303

fbshipit-source-id: 5bdeb29d473446d0350d397e48eb70eeeaed16d3
mcremon-meta authored and facebook-github-bot committed Jun 21, 2024
1 parent 5e22836 commit 78688b7
Showing 2 changed files with 68 additions and 16 deletions.
57 changes: 48 additions & 9 deletions backends/cadence/aot/passes.py
@@ -4,20 +4,34 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Dict, Tuple
+
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
 from torch._subclasses import FakeTensor
 from torch.utils._pytree import tree_map_only
 
 
+# pyre-strict
+
+# Similar to what's done in executorch/exir/pass_base.py
+Argument = Any  # pyre-ignore
+
+
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
     Replace the pt2 quantization ops with custom cadence quantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -34,7 +48,13 @@ class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     Replace the pt2 dequantization ops with custom cadence dequantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -51,7 +71,13 @@ class ReplaceScalarTensorWithFullPass(ExportPass):
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {
             exir_ops.edge.aten.scalar_tensor.default,
             torch.ops.aten.scalar_tensor.default,
@@ -64,7 +90,7 @@ def call_operator(self, op, args, kwargs, meta):
                 [1],
                 args[0],
             ),
-            {},
+            {"dtype": torch.float32},
             meta,
         )
 
@@ -75,7 +101,13 @@ class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
     view_copy op
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
         # which allows us to cover all overloads.
         if get_edge_overload_packet(op) not in {
@@ -99,7 +131,13 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op != exir_ops.edge.aten.cat.default:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -122,6 +160,7 @@ def call_operator(self, op, args, kwargs, meta):
         # TODO(matthiascremon): confirm this is the best way to do this.
         if isinstance(result, FakeTensor):
             result.constant = result
+            # pyre-ignore[7]: Incompatible return type.
             return torch.empty_like(result)
 
         # If there was only one tensor in the new_args list,
@@ -130,7 +169,7 @@ def call_operator(self, op, args, kwargs, meta):
             return new_args[0]
 
         # Otherwise, we replace args[0] with new_args.
-        args = list(args)
-        args[0] = new_args
+        init_args = list(args)
+        init_args[0] = new_args
         args = tuple(args)
         return super().call_operator(op, args, kwargs, meta)
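
The passes.py side of this change makes ReplaceScalarTensorWithFullPass emit aten.full with an explicit dtype, so the Cadence backend sees a fully determined output type ahead of time instead of deriving one from the fill value. As a rough illustration, here is a minimal standalone sketch of the same rewrite over a plain torch.fx graph; the helper name, the use of raw fx instead of ExportPass, and the fixed torch.float32 are illustrative assumptions, not the ExecuTorch API:

import torch
from torch import fx


def replace_scalar_tensor_with_full(gm: fx.GraphModule) -> fx.GraphModule:
    # Sketch only: mirrors ReplaceScalarTensorWithFullPass above on a raw fx graph.
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.aten.scalar_tensor.default
        ):
            with gm.graph.inserting_after(node):
                # Pinning dtype up front spares the backend a runtime
                # dtype inference on the fill value.
                full_node = gm.graph.call_function(
                    torch.ops.aten.full.default,
                    ([1], node.args[0]),
                    {"dtype": torch.float32},  # assumed float payload
                )
            node.replace_all_uses_with(full_node)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm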
27 changes: 20 additions & 7 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from typing import Any, Dict, List, Tuple
 
 import torch
@@ -31,6 +33,11 @@
 from torch.fx.passes.utils.fuser_utils import legalize_graph
 
 
+# Use this to avoid pyre errors
+# pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
+ArgsType = Any
+
+
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
     graph_module: GraphModule,
@@ -40,7 +47,7 @@ def get_args_and_kwargs_linear(
     dequants_weights: List[fx.Node],
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the linear replacement op.
     """
@@ -98,7 +105,7 @@ def get_args_and_kwargs_layer_norm(
     dequants_inputs: List[fx.Node],
     other_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the layer norm replacement op.
     """
@@ -167,7 +174,7 @@ def get_args_and_kwargs_matmul(
     inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
     requantize_scale = (
         # pyre-ignore[58]: Unsupported operand
         dequants_inputs[0].args[1]
@@ -203,7 +210,7 @@ def get_args_and_kwargs_conv(
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
     op_node: fx.Node,
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     weight_scale = dequants_weights[0].args[1]
     weight_zero_point = dequants_weights[0].args[2]
     # pyre-fixme[58]: Unsupported operand types
@@ -277,12 +284,14 @@ def get_args_and_kwargs_relu(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
     dequants_inputs: List[fx.Node],
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
     X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2])
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[2]),
+        {"dtype": torch.int32},
     )
 
     kwargs = {
@@ -292,8 +301,10 @@ def get_args_and_kwargs_relu(
 
 
 class QuantFusion(ExportPass):
-    def __init__(self, patterns):
+    # pyre-ignore[2]: Parameter `patterns` has no type specified
+    def __init__(self, patterns) -> None:
         super().__init__()
+        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
         self.patterns = patterns
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
@@ -427,10 +438,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         graph_module.recompile()
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def is_fused(cls, nodes) -> bool:
         return any(cls.__qualname__ in n.meta for n in nodes)
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def mark_fused(cls, nodes) -> bool:
         for n in nodes:
             # pyre-fixme[7]: Incompatible return type
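
The fusion_pass.py side applies the same idea: the zero-point tensor built for the quantized relu replacement now pins dtype=torch.int32 at graph-construction time. Without an explicit dtype, aten.full infers the type from the Python fill value, so an integer zero point would come out as int64 and need a downstream conversion. A quick eager-mode check of that inference rule (stock PyTorch behavior, shown for illustration rather than taken from this diff):

import torch

# With no dtype, aten.full infers the type from the fill value...
print(torch.ops.aten.full.default([1], 0.5).dtype)  # torch.float32
print(torch.ops.aten.full.default([1], 7).dtype)    # torch.int64

# ...so pinning it, as the pass now does, produces the int32 tensor
# the quantized kernel presumably expects.
print(torch.ops.aten.full.default([1], 7, dtype=torch.int32).dtype)  # torch.int32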
