Add testing for variable GPU outputs in Python BE (#5166)
* Add testing for variable GPU outputs in Python BE

* Fix PyTorch version upgrade backward incompatibilities
Tabrizian authored Dec 14, 2022
1 parent 2fa2ceb commit 95e386b
Showing 5 changed files with 206 additions and 54 deletions.
103 changes: 65 additions & 38 deletions qa/L0_backend_python/io/io_test.py
@@ -59,43 +59,42 @@ class IOTest(tu.TestResultCollector):

def setUp(self):
self._shm_leak_detector = shm_util.ShmLeakDetector()
self._client = grpcclient.InferenceServerClient("localhost:8001")

def _run_ensemble_test(self):
model_name = "ensemble_io"
user_data = UserData()
with grpcclient.InferenceServerClient("localhost:8001") as client:
input0 = np.random.random([1000]).astype(np.float32)
client.start_stream(callback=partial(callback, user_data))
for model_1_in_gpu in [True, False]:
for model_2_in_gpu in [True, False]:
for model_3_in_gpu in [True, False]:
gpu_output = np.asarray(
[model_1_in_gpu, model_2_in_gpu, model_3_in_gpu],
dtype=bool)
inputs = [
grpcclient.InferInput(
"INPUT0", input0.shape,
np_to_triton_dtype(input0.dtype)),
grpcclient.InferInput(
"GPU_OUTPUT", gpu_output.shape,
np_to_triton_dtype(gpu_output.dtype))
]
inputs[0].set_data_from_numpy(input0)
inputs[1].set_data_from_numpy(gpu_output)
client.async_stream_infer(model_name=model_name,
inputs=inputs)
if TRIAL == 'default':
input0 = np.random.random([1000]).astype(np.float32)
self._client.start_stream(callback=partial(callback, user_data))
for model_1_in_gpu in [True, False]:
for model_2_in_gpu in [True, False]:
for model_3_in_gpu in [True, False]:
gpu_output = np.asarray(
[model_1_in_gpu, model_2_in_gpu, model_3_in_gpu],
dtype=bool)
inputs = [
grpcclient.InferInput("INPUT0", input0.shape,
np_to_triton_dtype(input0.dtype)),
grpcclient.InferInput(
"GPU_OUTPUT", gpu_output.shape,
np_to_triton_dtype(gpu_output.dtype))
]
inputs[0].set_data_from_numpy(input0)
inputs[1].set_data_from_numpy(gpu_output)
self._client.async_stream_infer(model_name=model_name,
inputs=inputs)
if TRIAL == 'default':
result = user_data._completed_requests.get()
output0 = result.as_numpy('OUTPUT0')
self.assertIsNotNone(output0)
self.assertTrue(np.all(output0 == input0))
else:
response_repeat = 2
for _ in range(response_repeat):
result = user_data._completed_requests.get()
output0 = result.as_numpy('OUTPUT0')
self.assertIsNotNone(output0)
self.assertTrue(np.all(output0 == input0))
else:
response_repeat = 2
for _ in range(response_repeat):
result = user_data._completed_requests.get()
output0 = result.as_numpy('OUTPUT0')
self.assertIsNotNone(output0)
self.assertTrue(np.all(output0 == input0))

def test_ensemble_io(self):
# Only run the shared memory leak detection with the default trial
@@ -107,17 +106,45 @@ def test_ensemble_io(self):

def test_empty_gpu_output(self):
model_name = 'dlpack_empty_output'
with httpclient.InferenceServerClient("localhost:8000") as client:
input_data = np.array([[1.0]], dtype=np.float32)
inputs = [
httpclient.InferInput("INPUT", input_data.shape,
np_to_triton_dtype(input_data.dtype))
]
inputs[0].set_data_from_numpy(input_data)
result = client.infer(model_name, inputs)
input_data = np.array([[1.0]], dtype=np.float32)
inputs = [
grpcclient.InferInput("INPUT", input_data.shape,
np_to_triton_dtype(input_data.dtype))
]
inputs[0].set_data_from_numpy(input_data)
result = self._client.infer(model_name, inputs)
output = result.as_numpy('OUTPUT')
self.assertIsNotNone(output)
self.assertEqual(output.size, 0)

def test_variable_gpu_output(self):
# Input is not important in this test
model_name = 'variable_gpu_output'
input_data = np.array([[1.0]], dtype=np.float32)
inputs = [
grpcclient.InferInput("INPUT", input_data.shape,
np_to_triton_dtype(input_data.dtype))
]
inputs[0].set_data_from_numpy(input_data)
user_data = UserData()

# The test sends five requests to the model and the model returns five
# responses with different GPU output shapes
num_requests = 5
for _ in range(num_requests):
result = self._client.async_infer(model_name=model_name,
inputs=inputs,
callback=partial(
callback, user_data))

for i in range(num_requests):
result = user_data._completed_requests.get()
if isinstance(result, InferenceServerException):
self.assertTrue(False, result)
output = result.as_numpy('OUTPUT')
self.assertIsNotNone(output)
self.assertEqual(output.size, i + 1)
np.testing.assert_almost_equal(output, np.ones(i + 1) * (i + 1))


if __name__ == '__main__':
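The hunk above is truncated before the UserData and callback helpers that these tests rely on. The QA tests conventionally use a queue-backed holder whose callback stores either the inference result or the returned error; a minimal sketch of that pattern (a readability aid, not the literal code from this file) looks like this:

import queue


class UserData:

    def __init__(self):
        # Completed results, or InferenceServerException objects on error,
        # are pushed here by the callback and drained by the tests.
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    # Triton's async/streaming clients invoke the callback with either a
    # result or an error; exactly one of the two is set.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)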
32 changes: 31 additions & 1 deletion qa/L0_backend_python/io/test.sh
@@ -96,7 +96,6 @@ for trial in $TRIALS; do
RET=1
fi
fi

set -e

kill $SERVER_PID
@@ -130,7 +129,38 @@ else
RET=1
fi
fi
set -e

kill $SERVER_PID
wait $SERVER_PID

# IOTest.test_variable_gpu_output
rm -rf models && mkdir models
mkdir -p models/variable_gpu_output/1/
cp ../../python_models/variable_gpu_output/model.py ./models/variable_gpu_output/1/
cp ../../python_models/variable_gpu_output/config.pbtxt ./models/variable_gpu_output/

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
RET=1
fi

set +e
python3 $UNITTEST_PY IOTest.test_variable_gpu_output > $CLIENT_LOG.test_variable_gpu_output
if [ $? -ne 0 ]; then
echo -e "\n***\n*** IOTest.test_variable_gpu_output FAILED. \n***"
cat $CLIENT_LOG.test_variable_gpu_output
RET=1
else
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
if [ $? -ne 0 ]; then
cat $CLIENT_LOG.test_variable_gpu_output
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

kill $SERVER_PID
23 changes: 8 additions & 15 deletions qa/python_models/dlpack_test/model.py
@@ -37,12 +37,11 @@ def test_pytorch_dlpack(self):
# Test different dtypes
pytorch_dtypes = [
torch.float16, torch.float32, torch.float64, torch.int8,
torch.int16, torch.int32, torch.int64, torch.uint8, torch.bool
torch.int16, torch.int32, torch.int64, torch.uint8
]

for pytorch_dtype in pytorch_dtypes:
pytorch_tensor = torch.rand([100], dtype=torch.float16) * 100
pytorch_tensor = pytorch_tensor.type(pytorch_dtype)
pytorch_tensor = torch.ones([100], dtype=pytorch_dtype)
dlpack_tensor = to_dlpack(pytorch_tensor)
pb_tensor = pb_utils.Tensor.from_dlpack('test_tensor',
dlpack_tensor)
@@ -54,14 +53,8 @@ def test_pytorch_dlpack(self):
pytorch_tensor_dlpack = from_dlpack(pb_tensor.to_dlpack())
self.assertTrue(torch.all(pytorch_tensor_dlpack == pytorch_tensor))

# DLPack does not properly support bool type:
# https://github.com/google/jax/issues/4719
if pytorch_dtype != torch.bool:
self.assertTrue(
pytorch_tensor.type() == pytorch_tensor_dlpack.type())
else:
self.assertFalse(
pytorch_tensor.type() == pytorch_tensor_dlpack.type())
self.assertTrue(
pytorch_tensor.type() == pytorch_tensor_dlpack.type())

def test_non_contiguous_error(self):
pytorch_tensor = torch.rand([20, 30], dtype=torch.float16)
@@ -92,13 +85,13 @@ def test_dlpack_gpu_tensors(self):
# Test different dtypes
pytorch_dtypes = [
torch.float16, torch.float32, torch.float64, torch.int8,
torch.int16, torch.int32, torch.int64, torch.uint8, torch.bool
torch.int16, torch.int32, torch.int64, torch.uint8
]

for pytorch_dtype in pytorch_dtypes:
pytorch_tensor = torch.rand(
[100], dtype=torch.float16, device='cuda') * 100
pytorch_tensor = pytorch_tensor.type(pytorch_dtype)
pytorch_tensor = torch.ones([100],
dtype=pytorch_dtype,
device='cuda')
dlpack_tensor = to_dlpack(pytorch_tensor)
pb_tensor = pb_utils.Tensor.from_dlpack('test_tensor',
dlpack_tensor)
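As context for the change above (deterministic torch.ones tensors instead of casting torch.rand output, and torch.bool dropped from the dtype list after the PyTorch upgrade), the round-trip each loop iteration performs can be reproduced outside Triton with torch.utils.dlpack alone. A minimal sketch for a single, arbitrarily chosen dtype:

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# Build a deterministic tensor (torch.ones accepts integer dtypes directly,
# unlike torch.rand), round-trip it through DLPack, and check that both the
# values and the dtype survive.
tensor = torch.ones([100], dtype=torch.int16)
roundtrip = from_dlpack(to_dlpack(tensor))
assert torch.all(roundtrip == tensor)
assert roundtrip.type() == tensor.type()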
55 changes: 55 additions & 0 deletions qa/python_models/variable_gpu_output/config.pbtxt
@@ -0,0 +1,55 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "variable_gpu_output"
backend: "python"
max_batch_size: 256

input [
{
name: "INPUT"
data_type: TYPE_FP32
dims: [ 1 ]
}
]
output [
{
name: "OUTPUT"
data_type: TYPE_FP32
dims: [ -1 ]
}
]

dynamic_batching {
max_queue_delay_microseconds: 1000000
}

instance_group [
{
count: 1
kind: KIND_GPU
}
]
47 changes: 47 additions & 0 deletions qa/python_models/variable_gpu_output/model.py
@@ -0,0 +1,47 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import triton_python_backend_utils as pb_utils
import torch
from torch.utils.dlpack import to_dlpack


class TritonPythonModel:

def execute(self, requests):
# The client will send 5 requests
assert (len(requests) == 5)
responses = []
for i, request in enumerate(requests):
# Create an (i+1)-element tensor with every element equal to (i+1)
output = torch.ones(i + 1, dtype=torch.float32, device='cuda')
output = output * (i + 1)
output_pb_tensor = pb_utils.Tensor.from_dlpack(
"OUTPUT", to_dlpack(output))
inference_response = pb_utils.InferenceResponse(
output_tensors=[output_pb_tensor])
responses.append(inference_response)
return responses
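The assertion on len(requests) relies on the five client requests being issued concurrently, presumably so that they reach execute() as one batch within the one-second max_queue_delay configured above. To make the output construction concrete, here is a small standalone sketch (assuming a CUDA-capable PyTorch build; pb_utils.Tensor.from_dlpack in the model consumes the same kind of capsule that to_dlpack yields here):

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

for i in range(5):
    # Request i produces an (i + 1)-element float32 CUDA tensor whose
    # elements all equal (i + 1).
    output = torch.ones(i + 1, dtype=torch.float32, device='cuda') * (i + 1)
    capsule = to_dlpack(output)       # handed to pb_utils.Tensor.from_dlpack in the model
    roundtrip = from_dlpack(capsule)  # consume the capsule to verify shape and values
    assert roundtrip.shape == (i + 1,)
    assert torch.all(roundtrip == (i + 1))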
