From 737100c53e1506a4fe0811dcfcea064929abb827 Mon Sep 17 00:00:00 2001 From: dyastremsky <58150256+dyastremsky@users.noreply.github.com> Date: Wed, 21 Jun 2023 11:08:26 -0700 Subject: [PATCH] Ensure L0_batch_input requests received in order (#5963) * Add print statements for debugging * Add debugging print statements * Test using grpc client with stream to fix race * Use streaming client in all non-batch tests * Switch all clients to streaming GRPC * Remove unused imports, vars * Address comments * Remove random comment * Set inputs as separate function * Split set inputs based on test type --- qa/L0_batch_input/batch_input_test.py | 202 +++++++++++++++----------- 1 file changed, 119 insertions(+), 83 deletions(-) diff --git a/qa/L0_batch_input/batch_input_test.py b/qa/L0_batch_input/batch_input_test.py index 46cc569ad1..d5dfe2763d 100644 --- a/qa/L0_batch_input/batch_input_test.py +++ b/qa/L0_batch_input/batch_input_test.py @@ -30,47 +30,63 @@ import unittest import numpy as np +from functools import partial +import queue import test_util as tu -import tritonhttpclient -from tritonclientutils import InferenceServerException +import tritonclient.grpc as grpcclient +from tritonclient.utils import InferenceServerException class BatchInputTest(tu.TestResultCollector): def setUp(self): + self.client = grpcclient.InferenceServerClient(url='localhost:8001') + + def callback(user_data, result, error): + if error: + user_data.put(error) + else: + user_data.put(result) + + self.client_callback = callback + + def set_inputs(self, shapes, input_name): self.dtype_ = np.float32 self.inputs = [] - # 4 set of inputs with shape [2], [4], [1], [3] - for value in [2, 4, 1, 3]: - self.inputs.append([ - tritonhttpclient.InferInput('RAGGED_INPUT', [1, value], "FP32") - ]) + for shape in shapes: + self.inputs.append( + [grpcclient.InferInput(input_name, [1, shape[0]], "FP32")]) self.inputs[-1][0].set_data_from_numpy( - np.full([1, value], value, np.float32)) - self.client = tritonhttpclient.InferenceServerClient( - url="localhost:8000", concurrency=len(self.inputs)) + np.full([1, shape[0]], shape[0], np.float32)) - def test_ragged_output(self): - model_name = "ragged_io" - - # The model is identity model + def set_inputs_for_batch_item(self, shapes, input_name): + self.dtype_ = np.float32 self.inputs = [] - for value in [2, 4, 1, 3]: + for shape in shapes: self.inputs.append( - [tritonhttpclient.InferInput('INPUT0', [1, value], "FP32")]) + [grpcclient.InferInput(input_name, shape, "FP32")]) self.inputs[-1][0].set_data_from_numpy( - np.full([1, value], value, np.float32)) + np.full(shape, shape[0], np.float32)) + + def test_ragged_output(self): + model_name = "ragged_io" + # The model is an identity model + self.set_inputs([[2], [4], [1], [3]], "INPUT0") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) + output_name = 'OUTPUT0' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. 
async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) expected_value_list = [[v] * v for v in [2, 4, 1, 3]] expected_value_list = [ @@ -80,7 +96,7 @@ def test_ragged_output(self): for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) @@ -90,21 +106,25 @@ def test_ragged_output(self): idx, expected_value_list[idx], output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_ragged_input(self): model_name = "ragged_acc_shape" + self.set_inputs([[2], [4], [1], [3]], "RAGGED_INPUT") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'RAGGED_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] - + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) value_lists = [[v] * v for v in [2, 4, 1, 3]] expected_value = [] @@ -114,8 +134,7 @@ def test_ragged_input(self): for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() - + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) self.assertTrue( @@ -124,27 +143,32 @@ def test_ragged_input(self): idx, expected_value, output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_element_count(self): model_name = "ragged_element_count_acc_zero" + self.set_inputs([[2], [4], [1], [3]], "RAGGED_INPUT") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_AND_SIZE_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) expected_value = np.asarray([[2, 4, 1, 3]], np.float32) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. 
output_data = result.as_numpy(output_name) @@ -154,27 +178,32 @@ def test_element_count(self): idx, expected_value, output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_accumulated_element_count(self): model_name = "ragged_acc_shape" + self.set_inputs([[2], [4], [1], [3]], "RAGGED_INPUT") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_AND_SIZE_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) expected_value = np.asarray([[2, 6, 7, 10]], np.float32) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) @@ -184,27 +213,32 @@ def test_accumulated_element_count(self): idx, expected_value, output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_accumulated_element_count_with_zero(self): model_name = "ragged_element_count_acc_zero" + self.set_inputs([[2], [4], [1], [3]], "RAGGED_INPUT") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) expected_value = np.asarray([[0, 2, 6, 7, 10]], np.float32) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) @@ -214,26 +248,31 @@ def test_accumulated_element_count_with_zero(self): idx, expected_value, output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_max_element_count_as_shape(self): model_name = "ragged_acc_shape" + self.set_inputs([[2], [4], [1], [3]], "RAGGED_INPUT") + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. 
async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) @@ -243,40 +282,38 @@ def test_max_element_count_as_shape(self): .format(idx, 4, output_data.shape)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_batch_item_shape_flatten(self): # Use 4 set of inputs with shape # [1, 4, 1], [1, 1, 2], [1, 1, 2], [1, 2, 2] # Note that the test only checks the formation of "BATCH_INPUT" where # the value of "RAGGED_INPUT" is irrelevant, only the shape matters - self.inputs = [] - for value in [[1, 4, 1], [1, 1, 2], [1, 1, 2], [1, 2, 2]]: - self.inputs.append( - [tritonhttpclient.InferInput('RAGGED_INPUT', value, "FP32")]) - self.inputs[-1][0].set_data_from_numpy( - np.full(value, value[0], np.float32)) - self.client = tritonhttpclient.InferenceServerClient( - url="localhost:8000", concurrency=len(self.inputs)) + self.set_inputs_for_batch_item( + [[1, 4, 1], [1, 1, 2], [1, 1, 2], [1, 2, 2]], "RAGGED_INPUT") model_name = "batch_item_flatten" + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for inputs in self.inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - self.client.async_infer(model_name=model_name, - inputs=inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) expected_value = np.asarray([[4, 1, 1, 2, 1, 2, 2, 2]], np.float32) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. 
output_data = result.as_numpy(output_name) @@ -286,19 +323,14 @@ def test_batch_item_shape_flatten(self): idx, expected_value, output_data)) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() def test_batch_item_shape(self): # Use 3 set of inputs with shape [2, 1, 2], [1, 1, 2], [1, 2, 2] # Note that the test only checks the formation of "BATCH_INPUT" where # the value of "RAGGED_INPUT" is irrelevant, only the shape matters - inputs = [] - for value in [[2, 1, 2], [1, 1, 2], [1, 2, 2]]: - inputs.append( - [tritonhttpclient.InferInput('RAGGED_INPUT', value, "FP32")]) - inputs[-1][0].set_data_from_numpy( - np.full(value, value[0], np.float32)) - client = tritonhttpclient.InferenceServerClient(url="localhost:8000", - concurrency=len(inputs)) + self.set_inputs_for_batch_item([[2, 1, 2], [1, 1, 2], [1, 2, 2]], + "RAGGED_INPUT") expected_outputs = [ np.array([[1.0, 2.0], [1.0, 2.0]]), @@ -307,23 +339,26 @@ def test_batch_item_shape(self): ] model_name = "batch_item" + user_data = queue.Queue() + self.client.start_stream( + callback=partial(self.client_callback, user_data)) output_name = 'BATCH_OUTPUT' - outputs = [tritonhttpclient.InferRequestedOutput(output_name)] + outputs = [grpcclient.InferRequestedOutput(output_name)] async_requests = [] try: - for request_inputs in inputs: + for input in self.inputs: # Asynchronous inference call. async_requests.append( - client.async_infer(model_name=model_name, - inputs=request_inputs, - outputs=outputs)) + self.client.async_stream_infer(model_name=model_name, + inputs=input, + outputs=outputs)) for idx in range(len(async_requests)): # Get the result from the initiated asynchronous inference request. # Note the call will block till the server responds. - result = async_requests[idx].get_result() + result = user_data.get() # Validate the results by comparing with precomputed values. output_data = result.as_numpy(output_name) @@ -334,6 +369,7 @@ def test_batch_item_shape(self): np.isclose(expected_outputs[idx], output_data))) except InferenceServerException as ex: self.assertTrue(False, "unexpected error {}".format(ex)) + self.client.stop_stream() if __name__ == '__main__':
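
The pattern the patch applies in every test is the same: open one gRPC stream, register a callback that pushes each response (or error) onto a queue, send all requests with async_stream_infer, then drain the queue. Sending the requests over a single stream is what ensures the server receives them in the intended order, which the test depends on for batch formation; the callback/queue pairing is simply how the streaming client hands results back. Below is a minimal standalone sketch of that pattern, not part of the patch itself. It assumes a Triton server reachable at localhost:8001 with the test's ragged_acc_shape model loaded; the URL, model name, and tensor names are taken from the test above.

    import queue
    from functools import partial

    import numpy as np
    import tritonclient.grpc as grpcclient
    from tritonclient.utils import InferenceServerException

    def callback(user_data, result, error):
        # Stream callback: enqueue the error if one occurred, else the result.
        user_data.put(error if error else result)

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    user_data = queue.Queue()
    client.start_stream(callback=partial(callback, user_data))

    # Send four ragged requests over the same stream so they reach the
    # server in this exact order (shapes [1,2], [1,4], [1,1], [1,3]).
    request_count = 0
    for length in [2, 4, 1, 3]:
        ragged_input = grpcclient.InferInput("RAGGED_INPUT", [1, length], "FP32")
        ragged_input.set_data_from_numpy(np.full([1, length], length, np.float32))
        client.async_stream_infer(
            model_name="ragged_acc_shape",
            inputs=[ragged_input],
            outputs=[grpcclient.InferRequestedOutput("RAGGED_OUTPUT")])
        request_count += 1

    # Drain the queue; get() blocks until the next streamed response arrives.
    for _ in range(request_count):
        result = user_data.get()
        if isinstance(result, InferenceServerException):
            raise result
        print(result.as_numpy("RAGGED_OUTPUT"))

    client.stop_stream()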