Add testing for empty GPU output tensor with CUDA device setting (#4921)
* Add testing for empty GPU output tensor with CUDA device setting

* Fix up
krishung5 authored Sep 29, 2022
1 parent d16fb98 commit da0d0f2
Showing 4 changed files with 154 additions and 7 deletions.
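In short, the commit adds coverage for a Python backend model that returns a zero-element output tensor resident on a CUDA device. As an illustrative sketch only (it assumes PyTorch and an available GPU, and is not part of the commit), the situation under test boils down to:

import torch

# A tensor with zero elements, placed on the first CUDA device.
empty_gpu = torch.ones((0,), dtype=torch.float32).to(torch.device("cuda:0"))
print(empty_gpu.shape, empty_gpu.numel())  # torch.Size([0]) 0

The server must hand such an empty GPU buffer back to the client, and the client should receive a non-None output with zero elements.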
21 changes: 18 additions & 3 deletions qa/L0_backend_python/io/io_test.py
@@ -32,6 +32,7 @@
import test_util as tu
import shm_util
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from tritonclient.utils import *
import numpy as np
import unittest
@@ -59,7 +60,7 @@ class IOTest(tu.TestResultCollector):
    def setUp(self):
        self._shm_leak_detector = shm_util.ShmLeakDetector()

-    def _run_test(self):
+    def _run_ensemble_test(self):
        model_name = "ensemble_io"
        user_data = UserData()
        with grpcclient.InferenceServerClient("localhost:8001") as client:
@@ -100,9 +101,23 @@ def test_ensemble_io(self):
        # Only run the shared memory leak detection with the default trial
        if TRIAL == 'default':
            with self._shm_leak_detector.Probe():
-                self._run_test()
+                self._run_ensemble_test()
        else:
-            self._run_test()
+            self._run_ensemble_test()

    def test_empty_gpu_output(self):
        model_name = 'dlpack_empty_output'
        with httpclient.InferenceServerClient("localhost:8000") as client:
            input_data = np.array([[1.0]], dtype=np.float32)
            inputs = [
                httpclient.InferInput("INPUT", input_data.shape,
                                      np_to_triton_dtype(input_data.dtype))
            ]
            inputs[0].set_data_from_numpy(input_data)
            result = client.infer(model_name, inputs)
            output = result.as_numpy('OUTPUT')
            self.assertIsNotNone(output)
            self.assertEqual(output.size, 0)


if __name__ == '__main__':
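For reference, the new case can be exercised on its own with the standard unittest runner; the snippet below is a hypothetical manual invocation (it assumes a server is already serving the dlpack_empty_output model on localhost:8000, that io_test.py is importable, and that any environment variables the module reads, such as TRIAL, are set):

import unittest

from io_test import IOTest  # hypothetical import; run from the io_test.py directory

suite = unittest.TestSuite()
suite.addTest(IOTest("test_empty_gpu_output"))
unittest.TextTestRunner(verbosity=2).run(suite)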
42 changes: 38 additions & 4 deletions qa/L0_backend_python/io/test.sh
@@ -45,6 +45,7 @@ rm -fr *.log ./models
pip3 uninstall -y torch
pip3 install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

# IOTest.test_ensemble_io
TRIALS="default decoupled"

for trial in $TRIALS; do
@@ -82,15 +83,15 @@ for trial in $TRIALS; do
fi

set +e
-python3 $UNITTEST_PY > $CLIENT_LOG
+python3 $UNITTEST_PY IOTest.test_ensemble_io > $CLIENT_LOG.test_ensemble_io
if [ $? -ne 0 ]; then
-echo -e "\n***\n*** io_test.py FAILED. \n***"
-cat $CLIENT_LOG
+echo -e "\n***\n*** IOTest.test_ensemble_io FAILED. \n***"
+cat $CLIENT_LOG.test_ensemble_io
RET=1
else
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
if [ $? -ne 0 ]; then
-cat $CLIENT_LOG
+cat $CLIENT_LOG.test_ensemble_io
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
@@ -102,6 +103,39 @@ for trial in $TRIALS; do
wait $SERVER_PID
done

# IOTest.test_empty_gpu_output
rm -rf models && mkdir models
mkdir -p models/dlpack_empty_output/1/
cp ../../python_models/dlpack_empty_output/model.py ./models/dlpack_empty_output/1/
cp ../../python_models/dlpack_empty_output/config.pbtxt ./models/dlpack_empty_output/

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
RET=1
fi

set +e
python3 $UNITTEST_PY IOTest.test_empty_gpu_output > $CLIENT_LOG.test_empty_gpu_output
if [ $? -ne 0 ]; then
echo -e "\n***\n*** IOTest.test_empty_gpu_output FAILED. \n***"
cat $CLIENT_LOG.test_empty_gpu_output
RET=1
else
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
if [ $? -ne 0 ]; then
cat $CLIENT_LOG.test_empty_gpu_output
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

set -e

kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** IO test PASSED.\n***"
else
43 changes: 43 additions & 0 deletions qa/python_models/dlpack_empty_output/config.pbtxt
@@ -0,0 +1,43 @@
# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "dlpack_empty_output"
max_batch_size: 8

input [
{
name: "INPUT"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
output [
{
name: "OUTPUT"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
55 changes: 55 additions & 0 deletions qa/python_models/dlpack_empty_output/model.py
@@ -0,0 +1,55 @@
# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack


class TritonPythonModel:

    def initialize(self, args):
        pass

    def execute(self, requests):
        responses = []

        for _ in requests:
            SHAPE = (0,)

            pytorch_tensor = torch.ones(SHAPE, dtype=torch.float32)

            device = torch.device("cuda:0")
            pytorch_tensor = pytorch_tensor.to(device)

            dlpack_tensor = to_dlpack(pytorch_tensor)
            pb_tensor = pb_utils.Tensor.from_dlpack('OUTPUT', dlpack_tensor)

            inference_response = pb_utils.InferenceResponse(
                output_tensors=[pb_tensor])
            responses.append(inference_response)

        return responses
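As a quick sanity check outside of Triton, the same zero-element tensor survives a DLPack round trip; a minimal sketch (assumes PyTorch 1.9+ with CUDA available, not part of the commit):

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# Zero-element tensor on the GPU, exported to DLPack and imported back.
empty_gpu = torch.ones((0,), dtype=torch.float32, device="cuda:0")
restored = from_dlpack(to_dlpack(empty_gpu))
assert restored.numel() == 0 and restored.device.type == "cuda"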
