15
15
from executorch .backends .arm .tosa_mapping import TosaArg
16
16
from executorch .backends .arm .tosa_specification import TosaSpecification
17
17
from executorch .backends .arm .tosa_utils import getNodeArgs , tosa_shape
18
+ from torch ._export .utils import (
19
+ get_buffer ,
20
+ get_lifted_tensor_constant ,
21
+ get_param ,
22
+ is_buffer ,
23
+ is_lifted_tensor_constant ,
24
+ is_param ,
25
+ )
18
26
from torch .export .exported_program import ExportedProgram
19
27
20
28
@@ -99,8 +107,7 @@ def process_inputs_to_parameters(
99
107
f"Failed processing parameter placeholder: { node .name } . "
100
108
"Is the original torch function supported?"
101
109
) from e
102
- parameter_name = edge_program .graph_signature .inputs_to_parameters [tosa_arg .name ]
103
- parameter_data = edge_program .state_dict [parameter_name ]
110
+ parameter_data = get_param (edge_program , node )
104
111
105
112
assert isinstance (parameter_data , torch .Tensor ), "Expect Attr to be tensor"
106
113
parameter_values = parameter_data .detach ().numpy ()
@@ -128,8 +135,7 @@ def process_inputs_to_buffers(
128
135
f"Failed processing buffer placeholder: { node .name } . "
129
136
"Is the original torch function supported?"
130
137
) from e
131
- buffer_name = edge_program .graph_signature .inputs_to_buffers [node .name ]
132
- buffer_data = edge_program .state_dict [buffer_name ]
138
+ buffer_data = get_buffer (edge_program , node )
133
139
134
140
assert isinstance (buffer_data , torch .Tensor ), "Expect Attr to be tensor"
135
141
buffer_values = buffer_data .detach ().numpy ()
@@ -156,11 +162,8 @@ def process_inputs_to_lifted_tensor_constants(
156
162
f"Failed processing lifted tensor constant placeholder: { node .name } . "
157
163
"Is the original torch function supported?"
158
164
) from e
159
- tensor_name = edge_program .graph_signature .inputs_to_lifted_tensor_constants [
160
- tosa_arg .name
161
- ]
162
- tensor = edge_program .tensor_constants [tensor_name ]
163
- tensor_data = tensor .detach ().numpy ()
165
+ tensor = get_lifted_tensor_constant (edge_program , node )
166
+ tensor_data = tensor .detach ().numpy () # type: ignore[union-attr]
164
167
165
168
tosa_graph .addConst (
166
169
tensor_data .shape , tosa_arg .dtype , tensor_data , name = tosa_arg .name
@@ -179,11 +182,11 @@ def process_placeholder(
179
182
180
183
if node .name in edge_program .graph_signature .user_inputs :
181
184
process_inputs (node , tosa_graph , tosa_spec )
182
- elif node . name in edge_program . graph_signature . inputs_to_parameters :
185
+ elif is_param ( edge_program , node ) :
183
186
process_inputs_to_parameters (node , tosa_graph , edge_program , tosa_spec )
184
- elif node . name in edge_program . graph_signature . inputs_to_buffers :
187
+ elif is_buffer ( edge_program , node ) :
185
188
process_inputs_to_buffers (node , tosa_graph , edge_program )
186
- elif node . name in edge_program . graph_signature . inputs_to_lifted_tensor_constants :
189
+ elif is_lifted_tensor_constant ( edge_program , node ) :
187
190
process_inputs_to_lifted_tensor_constants (node , tosa_graph , edge_program )
188
191
elif node .name in edge_program .graph_signature .inputs_to_lifted_custom_objs :
189
192
raise NotImplementedError (
0 commit comments