9
9
10
10
import argparse
11
11
import logging
12
+ import os
12
13
13
14
import torch
14
15
@@ -48,6 +49,19 @@ def get_model_and_inputs_from_name(model_name: str):
48
49
model , example_inputs , _ = EagerModelFactory .create_model (
49
50
* MODEL_NAME_TO_MODEL [model_name ]
50
51
)
52
+ # Case 3: Model is in an external python file loaded as a module.
53
+ # ModelUnderTest should be a torch.nn.module instance
54
+ # ModelInputs should be a tuple of inputs to the forward function
55
+ elif model_name .endswith (".py" ):
56
+ import importlib .util
57
+
58
+ # load model's module and add it
59
+ spec = importlib .util .spec_from_file_location ("tmp_model" , model_name )
60
+ module = importlib .util .module_from_spec (spec )
61
+ spec .loader .exec_module (module )
62
+ model = module .ModelUnderTest
63
+ example_inputs = module .ModelInputs
64
+
51
65
else :
52
66
raise RuntimeError (
53
67
f"Model '{ model_name } ' is not a valid name. Use --help for a list of available models."
@@ -133,7 +147,51 @@ def forward(self, x):
133
147
"softmax" : SoftmaxModule ,
134
148
}
135
149
136
- if __name__ == "__main__" :
150
+ targets = [
151
+ "ethos-u85-128" ,
152
+ "ethos-u55-128" ,
153
+ "TOSA" ,
154
+ ]
155
+
156
+
157
+ def get_compile_spec (target : str , intermediates : bool ) -> ArmCompileSpecBuilder :
158
+ spec_builder = None
159
+ if target == "TOSA" :
160
+ spec_builder = (
161
+ ArmCompileSpecBuilder ().tosa_compile_spec ().set_permute_memory_format (True )
162
+ )
163
+ elif target == "ethos-u55-128" :
164
+ spec_builder = (
165
+ ArmCompileSpecBuilder ()
166
+ .ethosu_compile_spec (
167
+ "ethos-u55-128" ,
168
+ system_config = "Ethos_U55_High_End_Embedded" ,
169
+ memory_mode = "Shared_Sram" ,
170
+ extra_flags = "--debug-force-regor --output-format=raw" ,
171
+ )
172
+ .set_permute_memory_format (args .model_name in MODEL_NAME_TO_MODEL .keys ())
173
+ .set_quantize_io (True )
174
+ )
175
+ elif target == "ethos-u85-128" :
176
+ spec_builder = (
177
+ ArmCompileSpecBuilder ()
178
+ .ethosu_compile_spec (
179
+ "ethos-u85-128" ,
180
+ system_config = "Ethos_U85_SYS_DRAM_Mid" ,
181
+ memory_mode = "Shared_Sram" ,
182
+ extra_flags = "--output-format=raw" ,
183
+ )
184
+ .set_permute_memory_format (True )
185
+ .set_quantize_io (True )
186
+ )
187
+
188
+ if intermediates is not None :
189
+ spec_builder .dump_intermediate_artifacts_to (args .intermediates )
190
+
191
+ return spec_builder .build ()
192
+
193
+
194
+ def get_args ():
137
195
parser = argparse .ArgumentParser ()
138
196
parser .add_argument (
139
197
"-m" ,
@@ -149,6 +207,15 @@ def forward(self, x):
149
207
default = False ,
150
208
help = "Flag for producing ArmBackend delegated model" ,
151
209
)
210
+ parser .add_argument (
211
+ "-t" ,
212
+ "--target" ,
213
+ action = "store" ,
214
+ required = False ,
215
+ default = "ethos-u55-128" ,
216
+ choices = targets ,
217
+ help = f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are { targets } " ,
218
+ )
152
219
parser .add_argument (
153
220
"-q" ,
154
221
"--quantize" ,
@@ -167,8 +234,26 @@ def forward(self, x):
167
234
parser .add_argument (
168
235
"--debug" , action = "store_true" , help = "Set the logging level to debug."
169
236
)
170
-
237
+ parser .add_argument (
238
+ "-i" ,
239
+ "--intermediates" ,
240
+ action = "store" ,
241
+ required = False ,
242
+ help = "Store intermediate output (like TOSA artefacts) somewhere." ,
243
+ )
244
+ parser .add_argument (
245
+ "-o" ,
246
+ "--output" ,
247
+ action = "store" ,
248
+ required = False ,
249
+ help = "Location for outputs, if not the default of cwd." ,
250
+ )
171
251
args = parser .parse_args ()
252
+ return args
253
+
254
+
255
+ if __name__ == "__main__" :
256
+ args = get_args ()
172
257
173
258
if args .debug :
174
259
logging .basicConfig (level = logging .DEBUG , format = FORMAT , force = True )
@@ -191,7 +276,7 @@ def forward(self, x):
191
276
):
192
277
raise RuntimeError (f"Model { args .model_name } cannot be delegated." )
193
278
194
- # 1. pick model from one of the supported lists
279
+ # Pick model from one of the supported lists
195
280
model , example_inputs = get_model_and_inputs_from_name (args .model_name )
196
281
model = model .eval ()
197
282
@@ -209,23 +294,18 @@ def forward(self, x):
209
294
_check_ir_validity = False ,
210
295
),
211
296
)
297
+
298
+ # As we can target multiple output encodings from ArmBackend, one must
299
+ # be specified.
300
+ compile_spec = (
301
+ get_compile_spec (args .target , args .intermediates )
302
+ if args .delegate is True
303
+ else None
304
+ )
305
+
212
306
logging .debug (f"Exported graph:\n { edge .exported_program ().graph } " )
213
307
if args .delegate is True :
214
- edge = edge .to_backend (
215
- ArmPartitioner (
216
- ArmCompileSpecBuilder ()
217
- .ethosu_compile_spec (
218
- "ethos-u55-128" ,
219
- system_config = "Ethos_U55_High_End_Embedded" ,
220
- memory_mode = "Shared_Sram" ,
221
- )
222
- .set_permute_memory_format (
223
- args .model_name in MODEL_NAME_TO_MODEL .keys ()
224
- )
225
- .set_quantize_io (True )
226
- .build ()
227
- )
228
- )
308
+ edge = edge .to_backend (ArmPartitioner (compile_spec ))
229
309
logging .debug (f"Lowered graph:\n { edge .exported_program ().graph } " )
230
310
231
311
try :
@@ -241,7 +321,12 @@ def forward(self, x):
241
321
else :
242
322
raise e
243
323
244
- model_name = f"{ args .model_name } " + (
245
- "_arm_delegate" if args .delegate is True else ""
324
+ model_name = os .path .basename (os .path .splitext (args .model_name )[0 ])
325
+ output_name = f"{ model_name } " + (
326
+ f"_arm_delegate_{ args .target } " if args .delegate is True else ""
246
327
)
247
- save_pte_program (exec_prog , model_name )
328
+
329
+ if args .output is not None :
330
+ output_name = os .path .join (args .output , output_name )
331
+
332
+ save_pte_program (exec_prog , output_name )
0 commit comments