Skip to content

Add a pass to fuse scalar mul with quant ops #10630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions backends/cadence/aot/simplify_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CadencePassAttribute,
register_cadence_pass,
)
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue
Expand Down Expand Up @@ -109,6 +110,53 @@ def call_operator(self, op, args, kwargs, meta):
return super().call_operator(op, new_args, kwargs, meta)


def FuseMulToQuantPass(
exported_program: ExportedProgram,
) -> ExportedProgram:
"""
If a mul op using a scalar constant input is followed by a quantize op, we can fuse the mul
into the quantize op by updating its scale. Unfortunately, lifted constants are not stored in
the nodes, so we need to find the constant value in the constants dict.
"""
graph = exported_program.graph_module.graph
for node in graph.nodes:
# We are only interested in mul ops
if node.target != exir_ops.edge.aten.mul.Tensor:
continue

# Only applies if the following op is a quantize op
user = list(node.users.keys())[0]
if user.target not in (
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
):
continue

# Check that the second arg of the mul is a constant
# Get the constant value
if node.args[1].name not in exported_program.state_dict:
continue

tensor = exported_program.state_dict[node.args[1].name]

args = list(user.args)

# Update the scale of the quantize op
args[0] = node.args[0]
args[1] = user.args[1] / tensor.item()

# Return the op with the updated args
with graph.inserting_before(user):
op_node = graph.call_function(user.target, args=tuple(args))
op_node.meta = node.meta
user.replace_all_uses_with(op_node)

exported_program.graph_module.recompile()
exported_program.graph_module.graph.eliminate_dead_code()

return exported_program


# This class encapsulates all the functions that simplify the op's args
class CadenceSimplifyOpsInGraph:
passes = [
Expand Down
27 changes: 26 additions & 1 deletion backends/cadence/aot/tests/test_simplify_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
import torch
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.simplify_ops import (
FuseMulToQuantPass,
SimplifySliceOpPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from torch.fx.passes.infra.pass_base import PassResult
Expand Down Expand Up @@ -112,3 +115,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
)


def test_fuse_mul_in_quant(self) -> None:
class FuseMulInQuant(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x * 0.6
z = torch.ops.cadence.quantize_per_tensor(
y, 0.1, 0, -32768, 32767, torch.int16
)
return z

model = FuseMulInQuant()
inputs = (torch.randn(1, 4, 16),)

exported_program = export_to_edge(model, inputs).exported_program()
exported_program = FuseMulToQuantPass(exported_program)

# Assert that the mul op was removed
self.assertEqual(
count_node(exported_program.graph_module, exir_ops.edge.aten.mul.Tensor),
0,
)
Loading