Commit 27c48b9

NXP backend: Replace move relu before concat optimization (#15394)
### Summary

This PR replaces the optimization in `move_relu_before_concat.py` with the `MoveActivationBeforeConcat` aten pass. The pass moves selected activations that are supported for fusion on Neutron (Relu, Relu6, Sigmoid, Tanh) before the `concat` node when all of the concat's input nodes are either Conv 2D or Linear 2D. The node-selection logic is determined by the target spec, which currently supports Neutron-C. Tests updated.

### Test plan

Unit tests provided (test_move_activation_before_concatenation.py).

cc @robert-kalmar
1 parent acf5b4b commit 27c48b9
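For illustration only, here is a minimal PyTorch sketch (not part of the commit; the module name and tensor shapes are invented) of the pattern the pass targets, several conv branches feeding a `cat` followed by one shared activation:

```python
import torch
from torch import nn


class TwoBranchModel(nn.Module):
    """Toy model exhibiting the conv -> cat -> activation pattern targeted by the pass."""

    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.conv_b = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def forward(self, x):
        # Two independent conv branches joined by `cat`, then a single activation.
        a = self.conv_a(x)
        b = self.conv_b(x)
        return torch.relu(torch.cat([a, b], dim=1))


# After the pass, the graph is expected to be equivalent to:
#     torch.cat([torch.relu(conv_a(x)), torch.relu(conv_b(x))], dim=1)
# which lets each relu be fused into its preceding convolution on Neutron.
```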

12 files changed (+1283 −161 lines)
backends/nxp/aten_passes/move_activation_before_concat.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@

```python
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class MoveActivationBeforeConcat(PassBase):
    """Move some operators around in the following pattern.
    This is a common pattern that emerges from the conversion of separable convolutions.
    This optimization works together with joint quantization of compute nodes and activations.
    Without it, it is not beneficial.

           │                    │                                    │                    │
    ┌──────▼──────┐      ┌──────▼──────┐                      ┌──────▼──────┐      ┌──────▼──────┐
    │ aten.conv2d │  ... │ aten.conv2d │                      │ aten.conv2d │  ... │ aten.conv2d │
    └──────┬──────┘      └──────┬──────┘                      └──────┬──────┘      └──────┬──────┘
           └───────┐     ┌──────┘                                    │                    │
                ┌──▼─────▼─┐               replace with        ┌─────▼─────┐        ┌─────▼─────┐
                │ aten.cat │              ──────────────►      │ aten.relu │   ...  │ aten.relu │
                └────┬─────┘                                   └─────┬─────┘        └─────┬─────┘
                     │                                               └───────┐     ┌──────┘
               ┌─────▼─────┐                                              ┌──▼─────▼─┐
               │ aten.relu │                                              │ aten.cat │
               └─────┬─────┘                                              └────┬─────┘
                     │                                                         │
    """

    def __init__(self, neutron_target_spec: NeutronTargetSpec):
        self.neutron_target_spec = neutron_target_spec

    def call(self, module: GraphModule) -> bool:
        def _is_concat(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.cat.default
            )

        made_changes = False

        for node in module.graph.nodes:
            if not _is_concat(node):
                continue  # Not a cat node.

            cat_node = node
            activation = next(iter(cat_node.users))

            # Check if all cat input nodes are conv 2D or linear 2D type and their only user is cat.
            if not all(
                self.neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten(
                    input_node
                )
                and len(input_node.users) == 1
                for input_node in cat_node.all_input_nodes
            ):
                continue

            # Check if the following activation is supported on Neutron as a fused activation.
            if not (
                len(cat_node.users) == 1
                and self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten(
                    activation
                )
            ):
                continue

            # Loop over all cat input nodes and insert a new activation after each node.
            for input_node in cat_node.all_input_nodes:
                with module.graph.inserting_after(input_node):
                    new_activation = module.graph.call_function(
                        activation.target,
                        args=(),
                        kwargs=activation.kwargs,
                    )

                new_activation.meta["source_fn_stack"] = [
                    (
                        new_activation.name,
                        activation.meta["source_fn_stack"][-1][-1],
                    )
                ]
                new_activation.meta["val"] = input_node.meta["val"]

                # Replace the uses of the input node with the new activation node.
                input_node.replace_all_uses_with(new_activation)
                new_activation.args = (input_node, *activation.args[1:])

            # Replace the uses of the activation node with the cat node.
            activation.replace_all_uses_with(cat_node)

            module.graph.erase_node(activation)

            made_changes = True

        return PassResult(module, made_changes)
```
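For orientation, a rough usage sketch of the new pass (not part of the commit). It assumes a `NeutronTargetSpec` instance named `neutron_target_spec` already exists (its construction is outside this diff), reuses the toy `TwoBranchModel` from the sketch above, and relies on `torch.export.export_for_training` from recent PyTorch releases:

```python
import torch

from executorch.backends.nxp.aten_passes.move_activation_before_concat import (
    MoveActivationBeforeConcat,
)

# Assumption: `neutron_target_spec` was built elsewhere by the NXP backend's
# target configuration; its constructor is not part of this commit's diff.
model = TwoBranchModel().eval()
example_inputs = (torch.randn(1, 8, 16, 16),)

# Capture an aten-level FX graph, similar to what the quantizer sees before its passes.
graph_module = torch.export.export_for_training(model, example_inputs).module()

# PassBase.__call__ runs call() and returns a PassResult.
result = MoveActivationBeforeConcat(neutron_target_spec)(graph_module)
print(result.modified)  # True when the relu was moved before the cat.
```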

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -16,6 +16,9 @@
 from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import (
     FuseLinearAndAddPass,
 )
+from executorch.backends.nxp.aten_passes.move_activation_before_concat import (
+    MoveActivationBeforeConcat,
+)
 from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
     RemoveNodesWithKnownOutputs,
 )
@@ -25,6 +28,7 @@
 from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
     SplitGRUBasedOnNumLayers,
 )
+from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
 from executorch.exir.pass_manager import PassManager
 from torch import nn
 from torch.fx.passes.infra.pass_base import PassResult
@@ -34,14 +38,17 @@
 
 class NeutronAtenPassManager(PassManager):
 
-    def __init__(self, passes: list[PassType] = None):
+    def __init__(
+        self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
+    ):
         passes: list[PassType] = passes or [
             FuseBatchNormWithConvPass(),
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
             SplitGRUBasedOnNumLayers(),
             RemoveNodesWithKnownOutputs(),
             FuseLinearAndAddPass(),
+            MoveActivationBeforeConcat(neutron_target_spec),
         ]
 
         super().__init__(passes)
```
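In other words, call sites now pass the target spec when constructing the pass manager. A minimal sketch (the `neutron_target_spec` and `graph_module` variables are assumptions, not defined in this diff):

```python
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# Previously: NeutronAtenPassManager()(graph_module)
# Now the target spec is a required first argument, and the default pass list
# ends with MoveActivationBeforeConcat(neutron_target_spec).
pass_manager = NeutronAtenPassManager(neutron_target_spec)
graph_module = pass_manager(graph_module).graph_module
```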

backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

backends/nxp/backend/ir/tflite_optimizer/optimizer.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -11,9 +11,6 @@
 
 from executorch.backends.nxp.backend.ir import logger
 from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
-from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import (
-    MoveActivationBeforeConcatenation,
-)
 from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.permute_fully_connected_weights_after_reshape import (
     PermuteFullyConnectedWeightsAfterReshape,
 )
@@ -29,8 +26,6 @@ class Optimization(Enum):
 
     PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE = 12
 
-    MOVE_ACTIVATION_BEFORE_CONCAT = 15
-
 
 class Optimizer:
     """
@@ -68,9 +63,6 @@ def __init__(
             Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape(
                 builder, conversion_config
             ),
-            Optimization.MOVE_ACTIVATION_BEFORE_CONCAT: MoveActivationBeforeConcatenation(
-                builder, conversion_config
-            ),
         }
 
     def optimize(
```

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -12,6 +12,7 @@
 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
 from executorch.backends.nxp.quantizer.patterns import (
     AbsPattern,
+    ActivationsConcatClusterPattern,
     AdaptiveAvgPoolPattern,
     AddmmPattern,
     AddTensorPattern,
@@ -225,13 +226,16 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
         self.op_to_applied_quantizer = {
             pt: False for q in self.quantizers for pt in q.pattern.partition_types()
         }
+        self.cluster_quantizers = [
+            NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
+        ]
 
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
         model.graph.eliminate_dead_code()  # Remove dead code to simplify the graph for the passes.
 
-        model = NeutronAtenPassManager()(model).graph_module
+        model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module
 
         model.graph.eliminate_dead_code()  # Remove dead code again, in case it was created by the passes.
 
@@ -240,6 +244,10 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         self._annotate_inputs(model)
 
+        # Annotate node clusters in model
+        for cluster_quantizer in self.cluster_quantizers:
+            cluster_quantizer.annotate(model)
+
         nodes = list(model.graph.nodes)
         for node in nodes:
             if (
```
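Taken together, the quantizer-side flow these hunks touch looks roughly like the sketch below. These methods are normally driven by the PT2E prepare step rather than called directly, and `neutron_target_spec` and `graph_module` are assumed to exist:

```python
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

quantizer = NeutronQuantizer(neutron_target_spec)

# Runs the aten passes, now including MoveActivationBeforeConcat, so the
# conv -> activation -> cat clusters exist before annotation.
graph_module = quantizer.transform_for_annotation(graph_module)

# Cluster quantizers annotate whole conv/activation/cat clusters first,
# then the per-node quantizers annotate the remaining nodes.
graph_module = quantizer.annotate(graph_module)
```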
