3939
4040_LOG = logging .getLogger (__name__ )
4141
42+ AOT_SUCCESS_TOKEN = "AOT_TEST_SUCCESS"
43+ AOT_FAILURE_TOKEN = "AOT_TEST_FAILURE"
44+
4245
4346class AOTTestModel (NamedTuple ):
4447 """Class to describe a model under test
@@ -64,6 +67,38 @@ class AOTTestModel(NamedTuple):
6467 params : Optional [Dict [str , np .array ]] = None
6568
6669
70+ class AOTTestRunner (NamedTuple ):
71+ """Class to describe a test runner for AOT code
72+
73+ Parameters
74+ ----------
75+ makefile: str
76+ Premade Makefile to use from the AOT test folder
77+ prologue: str
78+ Code to prepend to the main function
79+ includes: List[str]
80+ Additional includes required to run the AOT test runner
81+ parameters: Map[str, str]
82+ Additional parameters to pass to the make command
83+ """
84+
85+ makefile : str = "default"
86+ prologue : str = ""
87+ includes : List [str ] = []
88+ parameters : Dict [str , str ] = {}
89+
90+
91+ AOT_DEFAULT_RUNNER = AOTTestRunner ()
92+ AOT_CORSTONE300_RUNNER = AOTTestRunner (
93+ makefile = "corstone300" ,
94+ prologue = """
95+ uart_init();
96+ """ ,
97+ includes = ["uart.h" ],
98+ parameters = {"NPU_VARIANT" : "256" },
99+ )
100+
101+
67102def mangle_name (mod_name , name ):
68103 mod_name = mangle_module_name (mod_name )
69104 return mod_name + "_" + name
@@ -114,17 +149,27 @@ def parametrize_aot_options(test):
114149
115150 interface_api = ["packed" , "c" ]
116151 use_unpacked_api = [True , False ]
117- use_calculated_workspaces = [True , False ]
152+ test_runner = [AOT_DEFAULT_RUNNER , AOT_CORSTONE300_RUNNER ]
153+
154+ all_combinations = itertools .product (interface_api , use_unpacked_api , test_runner )
118155
119- all_combinations = itertools .product (interface_api , use_unpacked_api , use_calculated_workspaces )
120156 # Filter out packed operators with c interface
121157 valid_combinations = filter (
122- lambda parameters : not (parameters [0 ] == "c" and parameters [1 ] == False ),
158+ lambda parameters : not (parameters [0 ] == "c" and not parameters [1 ]),
123159 all_combinations ,
124160 )
125161
162+ # Only use reference system for C interface and unpacked API calls
163+ valid_combinations = filter (
164+ lambda parameters : not (
165+ parameters [2 ] == AOT_CORSTONE300_RUNNER
166+ and (parameters [0 ] == "packed" or not parameters [1 ])
167+ ),
168+ valid_combinations ,
169+ )
170+
126171 return pytest .mark .parametrize (
127- ["interface_api" , "use_unpacked_api" , "use_calculated_workspaces " ],
172+ ["interface_api" , "use_unpacked_api" , "test_runner " ],
128173 valid_combinations ,
129174 )(test )
130175
@@ -160,7 +205,7 @@ def subprocess_log_output(cmd, cwd, logfile):
160205 return proc .wait ()
161206
162207
163- def emit_main_prologue (main_file , workspace_bytes ):
208+ def emit_main_prologue (main_file , custom_prologue , workspace_bytes ):
164209 # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
165210 main_file .write (
166211 f"#define WORKSPACE_SIZE ({ workspace_bytes } + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n "
@@ -185,6 +230,7 @@ def emit_main_prologue(main_file, workspace_bytes):
185230int main(){\n
186231"""
187232 )
233+ main_file .write (custom_prologue )
188234
189235
190236def emit_main_data (main_file , input_map , output_list , mod_name ):
@@ -297,11 +343,11 @@ def emit_main_compare(main_file, output_list, mod_name):
297343 main_file .write (f"for (int i = 0; i<{ actual_data_name } { i } _len; i++){{\n " )
298344 if is_float_dtype :
299345 main_file .write (
300- f'if (fabs({ actual_data_name } { i } [i]-{ expected_data_name } { i } [i]) > 0.001f){{\n \t printf("ko \\ n");\n \t return -1;}}\n '
346+ f'if (fabs({ actual_data_name } { i } [i]-{ expected_data_name } { i } [i]) > 0.001f){{\n \t printf("{ AOT_FAILURE_TOKEN } \\ n");\n \t return -1;}}\n '
301347 )
302348 else :
303349 main_file .write (
304- f'if ({ actual_data_name } { i } [i]!={ expected_data_name } { i } [i]){{\n \t printf("ko \\ n");\n \t return -1;}}\n '
350+ f'if ({ actual_data_name } { i } [i]!={ expected_data_name } { i } [i]){{\n \t printf("{ AOT_FAILURE_TOKEN } \\ n");\n \t return -1;}}\n '
305351 )
306352 main_file .write ("}\n " )
307353
@@ -312,36 +358,40 @@ def emit_main_init_memory_manager(main_file):
312358
313359
314360def emit_main_epilogue (main_file ):
315- main_file .write ('printf("ok \\ n");' )
361+ main_file .write (f 'printf("{ AOT_SUCCESS_TOKEN } \\ n");' )
316362 main_file .write ("return 0;" )
317363 main_file .write ("}\n " )
318364
319365
320- def emit_main_common_includes (main_file ):
366+ def emit_main_common_includes (main_file , custom_includes ):
321367 main_file .write ("#include <stdio.h>\n " )
322368 main_file .write ("#include <math.h>\n " )
323369 main_file .write ('#include "tvm/runtime/c_runtime_api.h"\n ' )
324370 main_file .write ('#include "tvm/runtime/crt/stack_allocator.h"\n ' )
371+ for include in custom_includes :
372+ main_file .write (f'#include "{ include } "\n ' )
325373
326374
327375def emit_main_micro_include (main_file , mod_name ):
328376 main_file .write (f"#include <{ mangle_module_name (mod_name )} .h>\n " )
329377
330378
331- def create_main (test_name , models , output_path , interface_api , workspace_bytes ):
379+ def create_main (
380+ test_name , models , output_path , custom_includes , custom_prologue , interface_api , workspace_bytes
381+ ):
332382 file_path = pathlib .Path (f"{ output_path } /" + test_name ).resolve ()
333383 # create header file
334384 raw_path = file_path .with_suffix (".c" ).resolve ()
335385 with open (raw_path , "w" ) as main_file :
336- emit_main_common_includes (main_file )
386+ emit_main_common_includes (main_file , custom_includes )
337387
338388 if interface_api == "c" :
339389 for model in models :
340390 emit_main_micro_include (main_file , model .name )
341-
342- emit_main_prologue (main_file , workspace_bytes )
343391 for model in models :
344392 emit_main_data (main_file , model .inputs , model .outputs , model .name )
393+
394+ emit_main_prologue (main_file , custom_prologue , workspace_bytes )
345395 emit_main_init_memory_manager (main_file )
346396
347397 if interface_api == "c" :
@@ -396,9 +446,10 @@ def extract_main_workspace_size_bytes(extract_dir):
396446
397447def compile_and_run (
398448 models : Union [List [AOTTestModel ], AOTTestModel ],
449+ runner : AOTTestRunner ,
399450 interface_api ,
400451 use_unpacked_api ,
401- use_calculated_workspaces ,
452+ debug_calculated_workspaces = False ,
402453 workspace_byte_alignment = 8 ,
403454 enable_op_fusion = True ,
404455):
@@ -414,7 +465,7 @@ def compile_and_run(
414465 models = [models ]
415466
416467 # The calculated workspaces will not account for stack allocator tags used for debugging
417- if not use_calculated_workspaces :
468+ if debug_calculated_workspaces :
418469 cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK "
419470
420471 config = {"tir.disable_vectorize" : True }
@@ -452,10 +503,7 @@ def compile_and_run(
452503 t = tarfile .open (tar_file )
453504 t .extractall (base_path )
454505
455- if use_calculated_workspaces :
456- workspace_bytes += extract_main_workspace_size_bytes (base_path )
457- else :
458- workspace_bytes += 16384 * 1024
506+ workspace_bytes += extract_main_workspace_size_bytes (base_path )
459507
460508 for key in model .inputs :
461509 create_header_file (
@@ -480,31 +528,41 @@ def compile_and_run(
480528 "test.c" ,
481529 models ,
482530 build_path ,
531+ runner .includes ,
532+ runner .prologue ,
483533 interface_api ,
484534 workspace_bytes ,
485535 )
486536
487537 # Verify that compiles fine
488538 file_dir = os .path .dirname (os .path .abspath (__file__ ))
489539 codegen_path = os .path .join (base_path , "codegen" )
490- makefile = os .path .join (file_dir , "aot_test.mk" )
491- make_cmd = (
492- f"make CFLAGS='{ cflags } ' -f { makefile } build_dir="
493- + build_path
540+ makefile = os .path .join (file_dir , f"{ runner .makefile } .mk" )
541+ custom_params = " " .join ([f" { param } ='{ value } '" for param , value in runner .parameters .items ()])
542+ make_command = (
543+ f"make -f { makefile } build_dir={ build_path } "
544+ + f" CFLAGS='{ cflags } '"
494545 + f" TVM_ROOT={ file_dir } /../../../.."
546+ + f" AOT_TEST_ROOT={ file_dir } "
495547 + f" CODEGEN_ROOT={ codegen_path } "
496548 + f" STANDALONE_CRT_DIR={ tvm .micro .get_standalone_crt_dir ()} "
549+ + custom_params
497550 )
498551
499552 compile_log_path = os .path .join (build_path , "test_compile.log" )
500- ret = subprocess_log_output (make_cmd , "." , compile_log_path )
553+ compile_command = f"{ make_command } aot_test_runner"
554+ ret = subprocess_log_output (compile_command , "." , compile_log_path )
501555 assert ret == 0
502556
503557 # Verify that runs fine
504558 run_log_path = os .path .join (build_path , "test_run.log" )
505- ret = subprocess_log_output ("./aot_test_runner" , build_path , run_log_path )
559+ run_command = f"{ make_command } run"
560+ ret = subprocess_log_output (run_command , build_path , run_log_path )
506561 assert ret == 0
507562
563+ with open (run_log_path ) as run_log :
564+ assert AOT_SUCCESS_TOKEN in run_log .read ()
565+
508566
509567def generate_ref_data (mod , input_data , params = None , target = "llvm" ):
510568 """Generate reference data through executing the relay module"""
0 commit comments