Skip to content

Commit

Permalink
Improve parameters test (triton-inference-server#5514)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian authored Mar 21, 2023
1 parent 4f25388 commit 8f03231
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
24 changes: 20 additions & 4 deletions qa/L0_parameters/model_repository/parameter/1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,23 @@

import triton_python_backend_utils as pb_utils
import numpy as np
import json


class TritonPythonModel:

@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = [{'name': 'INPUT0', 'data_type': 'TYPE_FP32', 'dims': [1]}]
outputs = [{'name': 'OUTPUT0', 'data_type': 'TYPE_STRING', 'dims': [1]}]
outputs = [{
'name': 'key',
'data_type': 'TYPE_STRING',
'dims': [-1]
}, {
'name': 'value',
'data_type': 'TYPE_STRING',
'dims': [-1]
}]

config = auto_complete_model_config.as_dict()
input_names = []
Expand All @@ -58,10 +67,17 @@ def execute(self, requests):
# output.
responses = []
for request in requests:
output0 = np.asarray([request.parameters()], dtype=object)
output_tensor = pb_utils.Tensor("OUTPUT0", output0)
parameters = json.loads(request.parameters())
keys = []
values = []
for key, value in parameters.items():
keys.append(key)
values.append(value)
key_output = pb_utils.Tensor("key", np.asarray(keys, dtype=object))
value_output = pb_utils.Tensor("value",
np.asarray(values, dtype=object))
inference_response = pb_utils.InferenceResponse(
output_tensors=[output_tensor])
output_tensors=[key_output, value_output])
responses.append(inference_response)

return responses
12 changes: 10 additions & 2 deletions qa/L0_parameters/parameters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,16 @@ async def send_request_and_verify(self,
self.verify_outputs(result, parameters)

def verify_outputs(self, result, parameters):
result = result.as_numpy('OUTPUT0')
self.assertEqual(json.loads(result[0]), parameters)
keys = result.as_numpy('key')
values = result.as_numpy('value')
self.assertEqual(set(keys.astype(str).tolist()),
set(list(parameters.keys())))

# We have to convert the parameter values to string
expected_values = []
for expected_value in list(parameters.values()):
expected_values.append(str(expected_value))
self.assertEqual(set(values.astype(str).tolist()), set(expected_values))

async def test_grpc_parameter(self):
await self.send_request_and_verify(grpcclient, self.grpc)
Expand Down

0 comments on commit 8f03231

Please sign in to comment.