From 2ee9094ecef98f473d77b5266bf6b73ad7fab237 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 4 Aug 2021 20:06:28 -0400 Subject: [PATCH] [TEST] Refactor RPC test to isolate runs into a sub-function (#8656) We kill the rpc server in the del function. When a server co-exist with remote resources in the same function scope, the destruction order is not determined. This can cause server to be destructed before the actual remote array. As a side effect, it can cause sometime test to timeout due to waiting on the socket. --- tests/python/contrib/test_edgetpu_runtime.py | 7 +- tests/python/contrib/test_random.py | 21 ++-- tests/python/contrib/test_tflite_runtime.py | 24 ++-- tests/python/relay/test_vm.py | 42 +++---- tests/python/unittest/test_runtime_graph.py | 5 +- .../unittest/test_runtime_graph_debug.py | 6 +- .../test_runtime_module_based_interface.py | 48 ++++---- tests/python/unittest/test_runtime_rpc.py | 110 +++++++++++------- 8 files changed, 146 insertions(+), 117 deletions(-) diff --git a/tests/python/contrib/test_edgetpu_runtime.py b/tests/python/contrib/test_edgetpu_runtime.py index 7e59ab2e3cc6..2bf58106dfdc 100644 --- a/tests/python/contrib/test_edgetpu_runtime.py +++ b/tests/python/contrib/test_edgetpu_runtime.py @@ -51,7 +51,7 @@ def init_interpreter(model_path, target_edgetpu): interpreter = tflite.Interpreter(model_path=model_path) return interpreter - def check_remote(target_edgetpu=False): + def check_remote(server, target_edgetpu=False): tflite_model_path = get_tflite_model_path(target_edgetpu) # inference via tflite interpreter python apis @@ -67,7 +67,6 @@ def check_remote(target_edgetpu=False): tflite_output = interpreter.get_tensor(output_details[0]["index"]) # inference via remote tvm tflite runtime - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) dev = remote.cpu(0) if target_edgetpu: @@ -83,9 +82,9 @@ def check_remote(target_edgetpu=False): np.testing.assert_equal(out.numpy(), tflite_output) # Target CPU on coral board - check_remote() + check_remote(rpc.Server("127.0.0.1")) # Target EdgeTPU on coral board - check_remote(target_edgetpu=True) + check_remote(rpc.Server("127.0.0.1"), target_edgetpu=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index 446efaabce0d..31756b273d25 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -122,17 +122,20 @@ def test_rpc(dtype): return np_ones = np.ones((512, 512), dtype=dtype) - server = rpc.Server("127.0.0.1") - remote = rpc.connect(server.host, server.port) - value = tvm.nd.empty((512, 512), dtype, remote.cpu()) - random_fill = remote.get_function("tvm.contrib.random.random_fill") - random_fill(value) - assert np.count_nonzero(value.numpy()) == 512 * 512 + def check_remote(server): + remote = rpc.connect(server.host, server.port) + value = tvm.nd.empty((512, 512), dtype, remote.cpu()) + random_fill = remote.get_function("tvm.contrib.random.random_fill") + random_fill(value) - # make sure arithmentic doesn't overflow too - np_values = value.numpy() - assert np.isfinite(np_values * np_values + np_values).any() + assert np.count_nonzero(value.numpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.numpy() + assert np.isfinite(np_values * np_values + np_values).any() + + check_remote(rpc.Server("127.0.0.1")) for dtype in [ "bool", diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 93ab634feb15..6268a6aae615 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -128,18 +128,18 @@ def test_remote(): tflite_output = interpreter.get_tensor(output_details[0]["index"]) # inference via remote tvm tflite runtime - server = rpc.Server("127.0.0.1") - remote = rpc.connect(server.host, server.port) - a = remote.upload(tflite_model_path) - - with open(tflite_model_path, "rb") as model_fin: - runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.numpy(), tflite_output) - - server.terminate() + def check_remote(server): + remote = rpc.connect(server.host, server.port) + a = remote.upload(tflite_model_path) + + with open(tflite_model_path, "rb") as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.numpy(), tflite_output) + + check_remote(rpc.Server("127.0.0.1")) if __name__ == "__main__": diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6c229064b094..4496fe783459 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -879,29 +879,25 @@ def test_vm_rpc(): # Use local rpc server for testing. # Server must use popen so it doesn't inherit the current process state. It # will crash otherwise. - server = rpc.Server("localhost", port=9120) - remote = rpc.connect(server.host, server.port, session_timeout=10) - - # Upload the serialized Executable. - remote.upload(path) - # Get a handle to remote Executable. - rexec = remote.load_module("vm_library.so") - - ctx = remote.cpu() - # Build a VM out of the executable and context. - vm_factory = runtime.vm.VirtualMachine(rexec, ctx) - np_input = np.random.uniform(size=(10, 1)).astype("float32") - input_tensor = tvm.nd.array(np_input, ctx) - # Invoke its "main" function. - out = vm_factory.invoke("main", input_tensor) - # Check the result. - np.testing.assert_allclose(out.numpy(), np_input + np_input) - - # delete tensors before the server shuts down so we don't throw errors. - del input_tensor - del out - - server.terminate() + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") + + ctx = remote.cpu() + # Build a VM out of the executable and context. + vm_factory = runtime.vm.VirtualMachine(rexec, ctx) + np_input = np.random.uniform(size=(10, 1)).astype("float32") + input_tensor = tvm.nd.array(np_input, ctx) + # Invoke its "main" function. + out = vm_factory.invoke("main", input_tensor) + # Check the result. + np.testing.assert_allclose(out.numpy(), np_input + np_input) + + check_remote(rpc.Server("127.0.0.1")) def test_get_output_single(): diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index 1259e77afbf8..458952fb5641 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -65,9 +65,8 @@ def check_verify(): out = mod.get_output(0, tvm.nd.empty((n,))) np.testing.assert_equal(out.numpy(), a + 1) - def check_remote(): + def check_remote(server): mlib = tvm.build(s, [A, B], "llvm", name="myadd") - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) temp = utils.tempdir() dev = remote.cpu(0) @@ -115,7 +114,7 @@ def check_sharing(): del mod check_verify() - check_remote() + check_remote(rpc.Server("127.0.0.1")) check_sharing() diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index 192e0dad702f..cadc8ae6a4c0 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -32,6 +32,7 @@ @tvm.testing.requires_llvm +@tvm.testing.requires_rpc def test_graph_simple(): n = 4 A = te.placeholder((n,), name="A") @@ -160,9 +161,8 @@ def split_debug_line(i): # verify dump root delete after cleanup assert not os.path.exists(directory) - def check_remote(): + def check_remote(server): mlib = tvm.build(s, [A, B], "llvm", name="myadd") - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) temp = utils.tempdir() dev = remote.cpu(0) @@ -182,7 +182,7 @@ def check_remote(): np.testing.assert_equal(out.numpy(), a + 1) check_verify() - check_remote() + check_remote(rpc.Server("127.0.0.1")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 9bb05dfed65f..e984979ac14f 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -275,29 +275,31 @@ def verify_rpc_gpu_export(obj_format): from tvm import rpc - server = rpc.Server("127.0.0.1", port=9094) - remote = rpc.connect(server.host, server.port) - remote.upload(path_lib) - loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") - dev = remote.cuda() - - # raw api - gmod = loaded_lib["default"](dev) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data, device=dev)) - run() - out = get_output(0).numpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - # graph executor wrapper - gmod = graph_executor.GraphModule(loaded_lib["default"](dev)) - gmod.set_input("data", data) - gmod.run() - out = gmod.get_output(0).numpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + def check_remote(server): + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") + dev = remote.cuda() + + # raw api + gmod = loaded_lib["default"](dev) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, device=dev)) + run() + out = get_output(0).numpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph executor wrapper + gmod = graph_executor.GraphModule(loaded_lib["default"](dev)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).numpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + check_remote(rpc.Server("127.0.0.1")) for obj_format in [".so", ".tar"]: verify_cpu_export(obj_format) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index f90c9548ec02..22aea8d1fcea 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -53,6 +53,9 @@ ), ) +# NOTE: When writing tests, wrap remote related checking in a sub-function +# to ensure all the remote resources destructs before the server terminates + @tvm.testing.requires_rpc def test_bigendian_rpc(): @@ -90,38 +93,49 @@ def verify_rpc(remote, target, shape, dtype): def test_rpc_simple(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - f1 = client.get_function("rpc.test.addone") - assert f1(10) == 11 - f3 = client.get_function("rpc.test.except") - with pytest.raises(tvm._ffi.base.TVMError): - f3("abc") + def check_remote(): + f1 = client.get_function("rpc.test.addone") + assert f1(10) == 11 + f3 = client.get_function("rpc.test.except") + + with pytest.raises(tvm._ffi.base.TVMError): + f3("abc") + + f2 = client.get_function("rpc.test.strcat") + assert f2("abc", 11) == "abc:11" - f2 = client.get_function("rpc.test.strcat") - assert f2("abc", 11) == "abc:11" + check_remote() @tvm.testing.requires_rpc def test_rpc_runtime_string(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - func = client.get_function("rpc.test.runtime_str_concat") - x = tvm.runtime.container.String("abc") - y = tvm.runtime.container.String("def") - assert str(func(x, y)) == "abcdef" + + def check_remote(): + func = client.get_function("rpc.test.runtime_str_concat") + x = tvm.runtime.container.String("abc") + y = tvm.runtime.container.String("def") + assert str(func(x, y)) == "abcdef" + + check_remote() @tvm.testing.requires_rpc def test_rpc_array(): - x = np.ones((3, 4)) - server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - r_cpu = tvm.nd.array(x, remote.cpu(0)) - assert str(r_cpu.device).startswith("remote") - np.testing.assert_equal(r_cpu.numpy(), x) - fremote = remote.get_function("rpc.test.remote_array_func") - fremote(r_cpu) + + def check_remote(): + x = np.ones((3, 4)) + r_cpu = tvm.nd.array(x, remote.cpu(0)) + assert str(r_cpu.device).startswith("remote") + np.testing.assert_equal(r_cpu.numpy(), x) + fremote = remote.get_function("rpc.test.remote_array_func") + fremote(r_cpu) + + check_remote() @tvm.testing.requires_rpc @@ -129,13 +143,17 @@ def test_rpc_large_array(): # testcase of large array creation server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - dev = remote.cpu(0) - a_np = np.ones((5041, 720)).astype("float32") - b_np = np.ones((720, 192)).astype("float32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - np.testing.assert_equal(a.numpy(), a_np) - np.testing.assert_equal(b.numpy(), b_np) + + def check_remote(): + dev = remote.cpu(0) + a_np = np.ones((5041, 720)).astype("float32") + b_np = np.ones((720, 192)).astype("float32") + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + np.testing.assert_equal(a.numpy(), a_np) + np.testing.assert_equal(b.numpy(), b_np) + + check_remote() @tvm.testing.requires_rpc @@ -186,10 +204,14 @@ def check_minrpc(): def test_rpc_file_exchange(): server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - blob = bytearray(np.random.randint(0, 10, size=(10))) - remote.upload(blob, "dat.bin") - rev = remote.download("dat.bin") - assert rev == blob + + def check_remote(): + blob = bytearray(np.random.randint(0, 10, size=(10))) + remote.upload(blob, "dat.bin") + rev = remote.download("dat.bin") + assert rev == blob + + check_remote() @tvm.testing.requires_rpc @@ -321,9 +343,13 @@ def check_remote_link_cl(remote): def test_rpc_return_func(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - f1 = client.get_function("rpc.test.add_to_lhs") - fadd = f1(10) - assert fadd(12) == 22 + + def check_remote(): + f1 = client.get_function("rpc.test.add_to_lhs") + fadd = f1(10) + assert fadd(12) == 22 + + check_remote() @tvm.testing.requires_rpc @@ -386,14 +412,18 @@ def run_arr_test(): @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession() - f1 = client.get_function("rpc.test.add_to_lhs") - fadd = f1(10) - assert fadd(12) == 22 - - blob = bytearray(np.random.randint(0, 10, size=(10))) - client.upload(blob, "dat.bin") - rev = client.download("dat.bin") - assert rev == blob + + def check_remote(): + f1 = client.get_function("rpc.test.add_to_lhs") + fadd = f1(10) + assert fadd(12) == 22 + + blob = bytearray(np.random.randint(0, 10, size=(10))) + client.upload(blob, "dat.bin") + rev = client.download("dat.bin") + assert rev == blob + + check_remote() @tvm.testing.requires_rpc