|
16 | 16 | CadencePassAttribute,
|
17 | 17 | register_cadence_pass,
|
18 | 18 | )
|
| 19 | +from executorch.exir import ExportedProgram |
19 | 20 |
|
20 | 21 | from executorch.exir.dialects._ops import ops as exir_ops
|
21 | 22 | from executorch.exir.pass_base import ExportPass, ProxyValue
|
@@ -109,6 +110,53 @@ def call_operator(self, op, args, kwargs, meta):
|
109 | 110 | return super().call_operator(op, new_args, kwargs, meta)
|
110 | 111 |
|
111 | 112 |
|
| 113 | +def FuseMulToQuantPass( |
| 114 | + exported_program: ExportedProgram, |
| 115 | +) -> ExportedProgram: |
| 116 | + """ |
| 117 | + If a mul op using a scalar constant input is followed by a quantize op, we can fuse the mul |
| 118 | + into the quantize op by updating its scale. Unfortunately, lifted constants are not stored in |
| 119 | + the nodes, so we need to find the constant value in the constants dict. |
| 120 | + """ |
| 121 | + graph = exported_program.graph_module.graph |
| 122 | + for node in graph.nodes: |
| 123 | + # We are only interested in mul ops |
| 124 | + if node.target != exir_ops.edge.aten.mul.Tensor: |
| 125 | + continue |
| 126 | + |
| 127 | + # Only applies if the following op is a quantize op |
| 128 | + user = list(node.users.keys())[0] |
| 129 | + if user.target not in ( |
| 130 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 131 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 132 | + ): |
| 133 | + continue |
| 134 | + |
| 135 | + # Check that the second arg of the mul is a constant |
| 136 | + # Get the constant value |
| 137 | + if node.args[1].name not in exported_program.state_dict: |
| 138 | + continue |
| 139 | + |
| 140 | + tensor = exported_program.state_dict[node.args[1].name] |
| 141 | + |
| 142 | + args = list(user.args) |
| 143 | + |
| 144 | + # Update the scale of the quantize op |
| 145 | + args[0] = node.args[0] |
| 146 | + args[1] = user.args[1] / tensor.item() |
| 147 | + |
| 148 | + # Return the op with the updated args |
| 149 | + with graph.inserting_before(user): |
| 150 | + op_node = graph.call_function(user.target, args=tuple(args)) |
| 151 | + op_node.meta = node.meta |
| 152 | + user.replace_all_uses_with(op_node) |
| 153 | + |
| 154 | + exported_program.graph_module.recompile() |
| 155 | + exported_program.graph_module.graph.eliminate_dead_code() |
| 156 | + |
| 157 | + return exported_program |
| 158 | + |
| 159 | + |
112 | 160 | # This class encapsulates all the functions that simplify the op's args
|
113 | 161 | class CadenceSimplifyOpsInGraph:
|
114 | 162 | passes = [
|
|
0 commit comments