Skip to content

Commit e736798

Browse files
mcremon-meta authored and facebook-github-bot committed
Add a pass to fuse scalar mul with quant ops
Summary: As titled. When a mul op is followed by a quant op, and the mul op uses a scalar known at compile time, the scalar can be folded into the scale of the quant op.
Reviewed By: zonglinpeng
Differential Revision: D72617001
1 parent e500d87 commit e736798

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

backends/cadence/aot/simplify_ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19+
from executorch.exir import ExportedProgram
1920

2021
from executorch.exir.dialects._ops import ops as exir_ops
2122
from executorch.exir.pass_base import ExportPass, ProxyValue
@@ -109,6 +110,53 @@ def call_operator(self, op, args, kwargs, meta):
109110
return super().call_operator(op, new_args, kwargs, meta)
110111

111112

113+
def FuseMulToQuantPass(
    exported_program: ExportedProgram,
) -> ExportedProgram:
    """
    Fold a scalar multiplication into the scale of the quantize op that follows it.

    For ``quantize(x * c, scale, ...)`` where ``c`` is a single-element constant
    known at compile time, the graph is rewritten to ``quantize(x, scale / c, ...)``.
    Lifted constants are not stored in the nodes, so the value of ``c`` is looked
    up in ``exported_program.state_dict`` using the name of the mul's second
    argument.
    NOTE(review): constants that only live in ``exported_program.constants`` are
    not considered here -- confirm whether that lookup is also needed.

    Args:
        exported_program: The program whose graph is rewritten in place.

    Returns:
        The same ``exported_program`` object, with every eligible mul/quantize
        pair fused, dead nodes removed and the graph module recompiled.
    """
    graph_module = exported_program.graph_module
    graph = graph_module.graph
    state_dict = exported_program.state_dict
    quantize_targets = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
    )

    # Iterate over a snapshot: new quantize nodes are inserted while we walk.
    for node in list(graph.nodes):
        # We are only interested in mul ops
        if node.target != exir_ops.edge.aten.mul.Tensor:
            continue
        if len(node.args) < 2:
            continue

        # The second arg of the mul must be a graph node whose value is a
        # known constant. Non-node operands have no name to look up.
        constant_name = getattr(node.args[1], "name", None)
        if constant_name is None or constant_name not in state_dict:
            continue

        # Only a single-element, non-zero constant can be folded into a
        # per-tensor scale (a zero multiplier would require dividing by zero).
        tensor = state_dict[constant_name]
        if tensor.numel() != 1:
            continue
        multiplier = tensor.item()
        if multiplier == 0:
            continue

        # Fuse into every quantize op that consumes the mul. A mul with no
        # users, or with only non-quantize users, is left untouched. Users are
        # snapshotted because rewiring below mutates `node.users`.
        for user in list(node.users):
            if user.target not in quantize_targets:
                continue
            if len(user.args) < 2 or user.args[0] is not node:
                continue
            scale = user.args[1]
            # The scale must be a compile-time number to be updated here.
            if isinstance(scale, bool) or not isinstance(scale, (int, float)):
                continue

            # Bypass the mul and compensate through the scale of the quantize op
            args = list(user.args)
            args[0] = node.args[0]
            args[1] = scale / multiplier

            with graph.inserting_before(user):
                op_node = graph.call_function(
                    user.target, args=tuple(args), kwargs=dict(user.kwargs)
                )
            # The new node stands in for the quantize op, so it carries the
            # quantize op's metadata, not the metadata of the mul.
            op_node.meta = user.meta.copy()
            user.replace_all_uses_with(op_node)

    # Drop the now-unused mul/quantize nodes first so the regenerated code
    # reflects the final graph.
    graph.eliminate_dead_code()
    graph_module.recompile()

    return exported_program
158+
159+
112160
# This class encapsulates all the functions that simplify the op's args
113161
class CadenceSimplifyOpsInGraph:
114162
passes = [

backends/cadence/aot/tests/test_simplify_ops_passes.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
import torch
1515
from executorch.backends.cadence.aot.compiler import export_to_edge
1616
from executorch.backends.cadence.aot.pass_utils import count_node
17-
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
17+
from executorch.backends.cadence.aot.simplify_ops import (
18+
FuseMulToQuantPass,
19+
SimplifySliceOpPass,
20+
)
1821
from executorch.exir.dialects._ops import ops as exir_ops
1922
from parameterized.parameterized import parameterized
2023
from torch.fx.passes.infra.pass_base import PassResult
@@ -112,3 +115,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
112115
self.assertEqual(
113116
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
114117
)
118+
119+
120+
def test_fuse_mul_in_quant(self) -> None:
    """Check that a scalar mul feeding a quantize op is folded into its scale."""

    class FuseMulInQuant(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x * 0.6
            z = torch.ops.cadence.quantize_per_tensor(
                y, 0.1, 0, -32768, 32767, torch.int16
            )
            return z

    model = FuseMulInQuant()
    inputs = (torch.randn(1, 4, 16),)

    exported_program = export_to_edge(model, inputs).exported_program()
    exported_program = FuseMulToQuantPass(exported_program)
    graph_module = exported_program.graph_module

    # Assert that the mul op was removed
    self.assertEqual(
        count_node(graph_module, exir_ops.edge.aten.mul.Tensor),
        0,
    )

    # Assert that exactly one quantize op remains after the rewrite
    quantize_targets = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
    )
    quant_nodes = [
        node for node in graph_module.graph.nodes if node.target in quantize_targets
    ]
    self.assertEqual(len(quant_nodes), 1)

    # Assert that the multiplier was folded into the scale: 0.1 / 0.6.
    # The tolerance absorbs the float32 rounding of the lifted 0.6 constant.
    self.assertAlmostEqual(quant_nodes[0].args[1], 0.1 / 0.6, places=5)

0 commit comments

Comments
 (0)