Skip to content

Commit 6af34b2

Browse files
committed
simplify function for lintrunner
Signed-off-by: Rob Elliott <robert.elliott@arm.com>
1 parent 71607fa commit 6af34b2

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

examples/arm/aot_arm_compiler.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,14 @@ def forward(self, x):
154154
]
155155

156156

157-
def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
157+
def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
158+
spec_builder = None
158159
if target == "TOSA":
159-
return (
160+
spec_builder = (
160161
ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True)
161162
)
162163
elif target == "ethos-u55-128":
163-
return (
164+
spec_builder = (
164165
ArmCompileSpecBuilder()
165166
.ethosu_compile_spec(
166167
"ethos-u55-128",
@@ -172,7 +173,7 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
172173
.set_quantize_io(True)
173174
)
174175
elif target == "ethos-u85-128":
175-
return (
176+
spec_builder = (
176177
ArmCompileSpecBuilder()
177178
.ethosu_compile_spec(
178179
"ethos-u85-128",
@@ -183,8 +184,13 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
183184
.set_permute_memory_format(True)
184185
)
185186

187+
if intermediates is not None:
188+
spec_builder.dump_intermediate_artifacts_to(args.intermediates)
186189

187-
if __name__ == "__main__":
190+
return spec_builder.build()
191+
192+
193+
def get_args():
188194
parser = argparse.ArgumentParser()
189195
parser.add_argument(
190196
"-m",
@@ -241,8 +247,12 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
241247
required=False,
242248
help="Location for outputs, if not the default of cwd.",
243249
)
244-
245250
args = parser.parse_args()
251+
return args
252+
253+
254+
if __name__ == "__main__":
255+
args = get_args()
246256

247257
if args.debug:
248258
logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)
@@ -286,12 +296,11 @@ def get_compile_spec(target: str) -> ArmCompileSpecBuilder:
286296

287297
# As we can target multiple output encodings from ArmBackend, one must
288298
# be specified.
289-
compile_spec = None
290-
if args.delegate is True:
291-
compile_spec = get_compile_spec(args.target)
292-
if args.intermediates is not None:
293-
compile_spec.dump_intermediate_artifacts_to(args.intermediates)
294-
compile_spec = compile_spec.build()
299+
compile_spec = (
300+
get_compile_spec(args.target, args.intermediates)
301+
if args.delegate is True
302+
else None
303+
)
295304

296305
logging.debug(f"Exported graph:\n{edge.exported_program().graph}")
297306
if args.delegate is True:

0 commit comments

Comments
 (0)