Skip to content

Commit 50a3b55

Browse files
iliyan-georgiev-armkirklandsign
authored andcommitted
Arm backend: Placeholder processing now handles non-persistent buffers (#9994)
- Update process_placeholder and called functions to use _export.utils - Add test to validate non-persistent input buffers going forward - Remove workaround from test_llama Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
1 parent 8678c1d commit 50a3b55

File tree

3 files changed

+64
-30
lines changed

3 files changed

+64
-30
lines changed

backends/arm/process_node.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_specification import TosaSpecification
1717
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
18+
from torch._export.utils import (
19+
get_buffer,
20+
get_lifted_tensor_constant,
21+
get_param,
22+
is_buffer,
23+
is_lifted_tensor_constant,
24+
is_param,
25+
)
1826
from torch.export.exported_program import ExportedProgram
1927

2028

@@ -99,8 +107,7 @@ def process_inputs_to_parameters(
99107
f"Failed processing parameter placeholder: {node.name}. "
100108
"Is the original torch function supported?"
101109
) from e
102-
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
103-
parameter_data = edge_program.state_dict[parameter_name]
110+
parameter_data = get_param(edge_program, node)
104111

105112
assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
106113
parameter_values = parameter_data.detach().numpy()
@@ -128,8 +135,7 @@ def process_inputs_to_buffers(
128135
f"Failed processing buffer placeholder: {node.name}. "
129136
"Is the original torch function supported?"
130137
) from e
131-
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
132-
buffer_data = edge_program.state_dict[buffer_name]
138+
buffer_data = get_buffer(edge_program, node)
133139

134140
assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
135141
buffer_values = buffer_data.detach().numpy()
@@ -156,11 +162,8 @@ def process_inputs_to_lifted_tensor_constants(
156162
f"Failed processing lifted tensor constant placeholder: {node.name}. "
157163
"Is the original torch function supported?"
158164
) from e
159-
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
160-
tosa_arg.name
161-
]
162-
tensor = edge_program.tensor_constants[tensor_name]
163-
tensor_data = tensor.detach().numpy()
165+
tensor = get_lifted_tensor_constant(edge_program, node)
166+
tensor_data = tensor.detach().numpy() # type: ignore[union-attr]
164167

165168
tosa_graph.addConst(
166169
tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
@@ -179,11 +182,11 @@ def process_placeholder(
179182

180183
if node.name in edge_program.graph_signature.user_inputs:
181184
process_inputs(node, tosa_graph, tosa_spec)
182-
elif node.name in edge_program.graph_signature.inputs_to_parameters:
185+
elif is_param(edge_program, node):
183186
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
184-
elif node.name in edge_program.graph_signature.inputs_to_buffers:
187+
elif is_buffer(edge_program, node):
185188
process_inputs_to_buffers(node, tosa_graph, edge_program)
186-
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
189+
elif is_lifted_tensor_constant(edge_program, node):
187190
process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
188191
elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
189192
raise NotImplementedError(
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
import torch.nn as nn
8+
9+
from executorch.backends.arm.test.common import parametrize
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
TosaPipelineBI,
12+
TosaPipelineMI,
13+
)
14+
15+
16+
class NonPersistentBuffer(nn.Module):
17+
"""
18+
Min code version registering a non-persistent input buffer.
19+
"""
20+
21+
def __init__(self):
22+
super().__init__()
23+
self.register_buffer("test_buff", torch.rand(2, 2, 2, 2), persistent=False)
24+
25+
def forward(self, x):
26+
return x - self.test_buff
27+
28+
29+
test_input = {"input": (torch.ones(2, 2, 2, 2),)}
30+
31+
input_t = tuple[torch.Tensor]
32+
33+
34+
@parametrize("test_data", test_input)
35+
def test_non_persistent_buffer_MI(test_data: input_t):
36+
"""
37+
Test validates Arm backend handling of non-persistent buffers
38+
and ensures that there are no asserts or errors when they are used.
39+
"""
40+
TosaPipelineMI[input_t](NonPersistentBuffer(), test_data, "").run()
41+
42+
43+
@parametrize("test_data", test_input)
44+
def test_non_persistent_buffer_BI(test_data: input_t):
45+
"""
46+
Test validates Arm backend handling of non-persistent buffers
47+
and ensures that there are no asserts or errors when they are used.
48+
"""
49+
TosaPipelineBI[input_t](NonPersistentBuffer(), test_data, "").run()

backends/arm/test/models/test_llama.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ def prepare_model(self):
7979

8080
llama_model, llama_inputs, llama_meta = get_llama_model(args)
8181

82-
# TODO: Remove workaround since attention mask should not be persistent,
83-
# it only works if input shape is always the same
84-
freqs_c = "freqs_cos"
85-
freqs_s = "freqs_sin"
86-
for i in range(llama_model.n_layers):
87-
val = llama_model.layers[i].attention.get_buffer("mask")
88-
llama_model.layers[i].attention.register_buffer(
89-
"mask", val, persistent=True
90-
)
91-
val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
92-
llama_model.layers[i].attention.rope.register_buffer(
93-
freqs_c, val, persistent=True
94-
)
95-
val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
96-
llama_model.layers[i].attention.rope.register_buffer(
97-
freqs_s, val, persistent=True
98-
)
99-
10082
return llama_model, llama_inputs, llama_meta
10183

10284
def test_llama_tosa_MI(self):

0 commit comments

Comments
 (0)