
Commit

Add test to check the output memory type for onnx models (triton-inference-server#6033)

* Add test to check the output memory type for onnx models

* Remove unused import

* Address comment
krishung5 authored Jul 7, 2023
1 parent fd96f23 commit 0049763
Showing 3 changed files with 207 additions and 3 deletions.
85 changes: 82 additions & 3 deletions qa/L0_warmup/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -42,6 +42,9 @@ export CUDA_VISIBLE_DEVICES=0

CLIENT=../clients/image_client
CLIENT_LOG="./client.log"
CLIENT_PY=./python_unittest.py
EXPECTED_NUM_TESTS="1"
TEST_RESULT_FILE='test_results.txt'

IMAGE="../images/vulture.jpeg"

@@ -56,6 +59,7 @@ SERVER_LOG="./inference_server.log"
source ../common/util.sh

RET=0
rm -fr *.txt

for BACKEND in ${BACKENDS}; do
rm -f $SERVER_LOG $CLIENT_LOG
@@ -408,8 +412,83 @@ set -e
kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
# Test the onnx model to verify that the memory type of the output tensor
# remains unchanged when warmup is enabled
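# Torch is needed below: the unittest in bls_onnx_warmup/model.py converts the
# GPU output tensors to numpy through torch.utils.dlpack.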
pip3 uninstall -y torch
pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html

rm -fr models && mkdir models
cp -r /data/inferenceserver/${REPO_VERSION}/qa_model_repository/onnx_nobatch_float32_float32_float32 models/.
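# Append a GPU instance_group and a model_warmup stanza to the model config
# (the assembled fragment is shown after this script for reference).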
(cd models/onnx_nobatch_float32_float32_float32 && \
echo "" >> config.pbtxt && \
echo 'instance_group [{' >> config.pbtxt && \
echo ' kind : KIND_GPU' >> config.pbtxt && \
echo '}]' >> config.pbtxt && \
echo 'model_warmup [{' >> config.pbtxt && \
echo ' name : "sample"' >> config.pbtxt && \
echo ' batch_size: 1' >> config.pbtxt && \
echo ' inputs {' >> config.pbtxt && \
echo ' key: "INPUT0"' >> config.pbtxt && \
echo ' value: {' >> config.pbtxt && \
echo ' data_type: TYPE_FP32' >> config.pbtxt && \
echo " dims: 16" >> config.pbtxt && \
echo " zero_data: false" >> config.pbtxt && \
echo ' }' >> config.pbtxt && \
echo ' }' >> config.pbtxt && \
echo ' inputs {' >> config.pbtxt && \
echo ' key: "INPUT1"' >> config.pbtxt && \
echo ' value: {' >> config.pbtxt && \
echo ' data_type: TYPE_FP32' >> config.pbtxt && \
echo " dims: 16" >> config.pbtxt && \
echo " zero_data: false" >> config.pbtxt && \
echo ' }' >> config.pbtxt && \
echo ' }' >> config.pbtxt && \
echo '}]' >> config.pbtxt )

mkdir -p models/bls_onnx_warmup/1/
cp ../python_models/bls_onnx_warmup/model.py models/bls_onnx_warmup/1/
cp ../python_models/bls_onnx_warmup/config.pbtxt models/bls_onnx_warmup/.

cp ../L0_backend_python/python_unittest.py .
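# python_unittest.py now lives one directory higher than in L0_backend_python,
# so point its sys.path entry at the shared common/ helpers.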
sed -i 's#sys.path.append("../../common")#sys.path.append("../common")#g' python_unittest.py

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

set +e

export MODEL_NAME='bls_onnx_warmup'
python3 $CLIENT_PY >> $CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
echo -e "\n***\n*** 'bls_onnx_warmup' test FAILED. \n***"
cat $CLIENT_LOG
RET=1
else
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

set -e


kill $SERVER_PID
wait $SERVER_PID


if [ $RET -eq 1 ]; then
cat $CLIENT_LOG
cat $SERVER_LOG
echo -e "\n***\n*** Test Failed \n***"
else
echo -e "\n***\n*** Test Passed \n***"
fi

exit $RET
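For reference, the echo chain above appends the following fragment to onnx_nobatch_float32_float32_float32/config.pbtxt (reassembled from the script; indentation approximate):

    instance_group [{
      kind : KIND_GPU
    }]
    model_warmup [{
      name : "sample"
      batch_size: 1
      inputs {
        key: "INPUT0"
        value: {
          data_type: TYPE_FP32
          dims: 16
          zero_data: false
        }
      }
      inputs {
        key: "INPUT1"
        value: {
          data_type: TYPE_FP32
          dims: 16
          zero_data: false
        }
      }
    }]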
38 changes: 38 additions & 0 deletions qa/python_models/bls_onnx_warmup/config.pbtxt
@@ -0,0 +1,38 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "bls_onnx_warmup"
backend: "python"

output [
{
name: "OUTPUT0"
data_type: TYPE_FP16
dims: [ 1 ]
}
]

instance_group [{ kind: KIND_CPU }]
87 changes: 87 additions & 0 deletions qa/python_models/bls_onnx_warmup/model.py
@@ -0,0 +1,87 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import unittest

import numpy as np
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import from_dlpack


class PBBLSONNXWarmupTest(unittest.TestCase):

def test_onnx_output_mem_type(self):
        input0_np = np.random.randn(16).astype(np.float32)
        input1_np = np.random.randn(16).astype(np.float32)
input0 = pb_utils.Tensor('INPUT0', input0_np)
input1 = pb_utils.Tensor('INPUT1', input1_np)
infer_request = pb_utils.InferenceRequest(
model_name='onnx_nobatch_float32_float32_float32',
inputs=[input0, input1],
requested_output_names=['OUTPUT0', 'OUTPUT1'])

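        # exec() runs the BLS request synchronously inside the server process,
        # so the response tensors keep their original device placement.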
infer_response = infer_request.exec()

self.assertFalse(infer_response.has_error())

output0 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT0')
output1 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT1')

self.assertIsNotNone(output0)
self.assertIsNotNone(output1)

        # The memory type of the output tensors should remain GPU
self.assertFalse(output0.is_cpu())
self.assertFalse(output1.is_cpu())

expected_output_0 = input0.as_numpy() - input1.as_numpy()
expected_output_1 = input0.as_numpy() + input1.as_numpy()

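        # Bring the GPU-resident outputs back to host memory via DLPack so
        # they can be compared against the numpy reference values.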
        output0 = from_dlpack(output0.to_dlpack()).detach().cpu().numpy()
        output1 = from_dlpack(output1.to_dlpack()).detach().cpu().numpy()

self.assertTrue(np.all(output0 == expected_output_0))
self.assertTrue(np.all(output1 == expected_output_1))


class TritonPythonModel:

def execute(self, requests):
responses = []
for _ in requests:
# Run the unittest and store the results in InferenceResponse.
test = unittest.main('model', exit=False)
responses.append(
pb_utils.InferenceResponse([
pb_utils.Tensor(
'OUTPUT0',
np.array([test.result.wasSuccessful()],
dtype=np.float16))
]))
return responses
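The python_unittest.py harness copied from L0_backend_python is not part of this diff. Below is a minimal sketch of how such a harness might drive the in-model unittest, assuming it follows the usual Triton QA pattern; the class name, port, and method are illustrative, and the real file also records results to test_results.txt so that check_test_results can verify EXPECTED_NUM_TESTS.

    import os
    import unittest

    import tritonclient.grpc as grpcclient


    class PythonUnittest(unittest.TestCase):
        def test_python_unittest(self):
            # test.sh selects the model under test via the MODEL_NAME env var.
            model_name = os.environ["MODEL_NAME"]
            with grpcclient.InferenceServerClient("localhost:8001") as client:
                # The bls_onnx_warmup model declares no inputs; it runs its
                # embedded unittest and reports success in OUTPUT0.
                result = client.infer(model_name, inputs=[])
                output0 = result.as_numpy("OUTPUT0")
                # OUTPUT0 holds 1.0 when the in-model unittest passed.
                self.assertEqual(output0[0], 1)


    if __name__ == "__main__":
        unittest.main()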
