diff --git a/python/tvm/testing/usmp.py b/python/tvm/testing/usmp.py index a8607373b7044..87ab686d4f252 100644 --- a/python/tvm/testing/usmp.py +++ b/python/tvm/testing/usmp.py @@ -19,9 +19,12 @@ import tvm -def check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module): - """This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls. - If USMP is invoked, none of them should have TVMBAW calls""" +def is_tvm_backendallocworkspace_calls(mod: tvm.runtime.module) -> bool: + """TVMBackendAllocWorkspace call check. + + This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls. + If USMP is invoked, none of them should have TVMBAW calls + """ dso_modules = mod._collect_dso_modules() for dso_mod in dso_modules: if dso_mod.type_key not in ["c", "llvm"]: @@ -30,6 +33,7 @@ def check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module): ), 'Current AoT codegen flow should only produce type "c" or "llvm" runtime modules' source = dso_mod.get_source() - assert ( - source.count("TVMBackendAllocWorkspace") == 0 - ), "This is failing because USMP was unable to plan for every tir.allocate node" + if source.count("TVMBackendAllocWorkspace") != 0: + return False + + return True diff --git a/tests/python/contrib/test_hexagon/test_usmp.py b/tests/python/contrib/test_hexagon/test_usmp.py index 85ff03b3ea931..1dc031f561d82 100644 --- a/tests/python/contrib/test_hexagon/test_usmp.py +++ b/tests/python/contrib/test_hexagon/test_usmp.py @@ -24,13 +24,15 @@ from tvm import relay from tvm.relay.backend import Executor, Runtime from tvm.contrib.hexagon.session import Session -from tvm.testing.usmp import check_for_no_tvm_backendallocworkspace_calls +from tvm.testing.usmp import is_tvm_backendallocworkspace_calls from .conftest import requires_hexagon_toolchain +usmp_enabled = tvm.testing.parameter(False, True) + @requires_hexagon_toolchain -def test_conv2d(hexagon_session: Session, aot_host_target, aot_target): +def test_conv2d(hexagon_session: Session, aot_host_target, aot_target, usmp_enabled): dtype = "float32" input_shape = (1, 8, 8, 3) w1_shape = (5, 5, 3, 1) @@ -73,7 +75,7 @@ def test_conv2d(hexagon_session: Session, aot_host_target, aot_target): params = {"weight1": weight1_data, "weight2": weight2_data} inputs = {"data": input_data} - with tvm.transform.PassContext(opt_level=3, config={"tir.usmp.enable": True}): + with tvm.transform.PassContext(opt_level=3, config={"tir.usmp.enable": usmp_enabled}): lowered = tvm.relay.build( relay_mod, params=params, @@ -82,7 +84,7 @@ def test_conv2d(hexagon_session: Session, aot_host_target, aot_target): executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) - check_for_no_tvm_backendallocworkspace_calls(lowered.lib) + assert is_tvm_backendallocworkspace_calls(lowered.lib) == usmp_enabled aot_mod = hexagon_session.get_executor_from_factory(lowered) aot_mod.set_input(**inputs) diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 46a6db62eacd1..524dae37570bb 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -43,7 +43,13 @@ run_and_check, create_relay_module_and_inputs_from_tflite_file, ) -from tvm.testing.usmp import check_for_no_tvm_backendallocworkspace_calls +from tvm.testing.usmp import is_tvm_backendallocworkspace_calls + + +def _check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module): + assert ( + is_tvm_backendallocworkspace_calls(mod) + ), "This is failing because USMP was unable to plan for every tir.allocate node." @pytest.mark.parametrize( @@ -125,7 +131,7 @@ def test_conv2d(interface_api, use_unpacked_api, test_runner, groups, weight_sha ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -184,7 +190,7 @@ def test_byoc_microtvm(merge_compiler_regions): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -238,7 +244,7 @@ def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) # Checking the workspace size reported in model library format mlf_memory_map = mlf._build_function_memory_map( @@ -317,7 +323,7 @@ def test_tflite_model_u3_usecase_single_external_pool(model_url, usmp_algo): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -377,7 +383,7 @@ def test_tflite_model_u3_usecase_two_external_pools(model_url, usmp_algo): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -445,7 +451,7 @@ def test_tflite_model_u2_usecase_two_models_with_a_single_external_pool(model_ur ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -513,7 +519,7 @@ def test_tflite_model_u4_usecase_single_external_pool(model_url, usmp_algo): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods, @@ -589,7 +595,7 @@ def test_tflite_model_u4_usecase_two_external_pools(model_url, usmp_algo): ) for compiled_model in compiled_test_mods: - check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + _check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) run_and_check( models=compiled_test_mods,