@@ -154,13 +154,14 @@ def forward(self, x):
154
154
]
155
155
156
156
157
- def get_compile_spec (target : str ) -> ArmCompileSpecBuilder :
157
+ def get_compile_spec (target : str , intermediates : bool ) -> ArmCompileSpecBuilder :
158
+ spec_builder = None
158
159
if target == "TOSA" :
159
- return (
160
+ spec_builder = (
160
161
ArmCompileSpecBuilder ().tosa_compile_spec ().set_permute_memory_format (True )
161
162
)
162
163
elif target == "ethos-u55-128" :
163
- return (
164
+ spec_builder = (
164
165
ArmCompileSpecBuilder ()
165
166
.ethosu_compile_spec (
166
167
"ethos-u55-128" ,
@@ -172,7 +173,7 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
172
173
.set_quantize_io (True )
173
174
)
174
175
elif target == "ethos-u85-128" :
175
- return (
176
+ spec_builder = (
176
177
ArmCompileSpecBuilder ()
177
178
.ethosu_compile_spec (
178
179
"ethos-u85-128" ,
@@ -183,8 +184,13 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
183
184
.set_permute_memory_format (True )
184
185
)
185
186
187
+ if intermediates is not None :
188
+ spec_builder .dump_intermediate_artifacts_to (args .intermediates )
186
189
187
- if __name__ == "__main__" :
190
+ return spec_builder .build ()
191
+
192
+
193
+ def get_args ():
188
194
parser = argparse .ArgumentParser ()
189
195
parser .add_argument (
190
196
"-m" ,
@@ -241,8 +247,12 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
241
247
required = False ,
242
248
help = "Location for outputs, if not the default of cwd." ,
243
249
)
244
-
245
250
args = parser .parse_args ()
251
+ return args
252
+
253
+
254
+ if __name__ == "__main__" :
255
+ args = get_args ()
246
256
247
257
if args .debug :
248
258
logging .basicConfig (level = logging .DEBUG , format = FORMAT , force = True )
@@ -286,12 +296,11 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
286
296
287
297
# As we can target multiple output encodings from ArmBackend, one must
288
298
# be specified.
289
- compile_spec = None
290
- if args .delegate is True :
291
- compile_spec = get_compile_spec (args .target )
292
- if args .intermediates is not None :
293
- compile_spec .dump_intermediate_artifacts_to (args .intermediates )
294
- compile_spec = compile_spec .build ()
299
+ compile_spec = (
300
+ get_compile_spec (args .target , args .intermediates )
301
+ if args .delegate is True
302
+ else None
303
+ )
295
304
296
305
logging .debug (f"Exported graph:\n { edge .exported_program ().graph } " )
297
306
if args .delegate is True :
0 commit comments