From 81e5a392240399f3a538e038651dae64f777f502 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian
Date: Wed, 15 Mar 2023 15:49:06 -0400
Subject: [PATCH] Add parsing parameters to the HTTP and GRPC frontends (#5490)

* Add parsing parameters to the HTTP frontend
* Add parameters to the GRPC server
* Add testing for parameters
* Fix up
* Add testing for async and streaming
* Add documentation and reserved parameters list
* Modify based on feedback
---
 docs/protocol/README.md | 1 +
 docs/protocol/extension_parameters.md | 87 +++++++++
 .../model_repository/parameter/1/model.py | 67 +++++++
 qa/L0_parameters/parameters_test.py | 184 ++++++++++++++++++
 qa/L0_parameters/test.sh | 78 ++++++++
 src/grpc_server.cc | 34 ++++
 src/http_server.cc | 122 +++++++-----
 7 files changed, 522 insertions(+), 51 deletions(-)
 create mode 100644 docs/protocol/extension_parameters.md
 create mode 100644 qa/L0_parameters/model_repository/parameter/1/model.py
 create mode 100644 qa/L0_parameters/parameters_test.py
 create mode 100644 qa/L0_parameters/test.sh

diff --git a/docs/protocol/README.md b/docs/protocol/README.md
index d28dc2b06a..808751957c 100644
--- a/docs/protocol/README.md
+++ b/docs/protocol/README.md
@@ -44,6 +44,7 @@ plus several extensions that are defined in the following documents:
 - [Statistics extension](./extension_statistics.md)
 - [Trace extension](./extension_trace.md)
 - [Logging extension](./extension_logging.md)
+- [Parameters extension](./extension_parameters.md)
 
 For the GRPC protocol, the
 [protobuf specification](https://github.com/triton-inference-server/common/blob/main/protobuf/grpc_service.proto)
diff --git a/docs/protocol/extension_parameters.md b/docs/protocol/extension_parameters.md
new file mode 100644
index 0000000000..b22fbe79a8
--- /dev/null
+++ b/docs/protocol/extension_parameters.md
@@ -0,0 +1,87 @@
+
+# Parameters Extension
+
+This document describes Triton's parameters extension. The
+parameters extension allows an inference request to provide
+custom parameters that cannot be provided as inputs. Because this extension is
+supported, Triton reports “parameters” in the extensions field of its
+Server Metadata. This extension uses the optional "parameters"
+field in the KServe Protocol in
+[HTTP](https://kserve.github.io/website/0.10/modelserving/data_plane/v2_protocol/#inference-request-json-object)
+and
+[GRPC](https://kserve.github.io/website/0.10/modelserving/data_plane/v2_protocol/#parameters).
+
+The following parameters are reserved for Triton's usage and should not be
+used as custom parameters:
+
+- sequence_id
+- priority
+- timeout
+- sequence_start
+- sequence_end
+- All keys that start with the "triton_" prefix.
+- headers
+
+Whether you use the GRPC or the HTTP endpoint, make sure not to use any of
+the reserved parameter names, to avoid unexpected behavior. The reserved
+parameters are not accessible through Triton's C API.
+
+## HTTP/REST
+
+The following example shows how a request can include custom parameters.
+
+```
+POST /v2/models/mymodel/infer HTTP/1.1
+Host: localhost:8000
+Content-Type: application/json
+Content-Length:
+{
+  "parameters" : { "my_custom_parameter" : 42 },
+  "inputs" : [
+    {
+      "name" : "input0",
+      "shape" : [ 2, 2 ],
+      "datatype" : "UINT32",
+      "data" : [ 1, 2, 3, 4 ]
+    }
+  ],
+  "outputs" : [
+    {
+      "name" : "output0"
+    }
+  ]
+}
+```
+
+## GRPC
+
+The `parameters` field in the
+ModelInferRequest message can be used to send custom parameters.
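
As an illustration, the sketch below shows how a Python GRPC client could attach a custom parameter to a request. It assumes a tritonclient release whose `infer()` call accepts a `parameters` keyword argument (as exercised by the QA test added in this PR); the model name `mymodel` and the input `input0` are placeholders.

```python
# Minimal sketch: sending a custom parameter over GRPC. Assumes a tritonclient
# version that supports the `parameters` keyword argument on infer(); the
# model name "mymodel" and input "input0" are placeholders.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

inputs = [grpcclient.InferInput("input0", [2, 2], "UINT32")]
inputs[0].set_data_from_numpy(np.asarray([[1, 2], [3, 4]], dtype=np.uint32))

# Custom parameters are carried in the request's "parameters" field. Keys with
# the reserved "triton_" prefix (and the other reserved names above) are
# rejected by the server.
result = client.infer(
    model_name="mymodel",
    inputs=inputs,
    parameters={"my_custom_parameter": 42},
)
```
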
diff --git a/qa/L0_parameters/model_repository/parameter/1/model.py b/qa/L0_parameters/model_repository/parameter/1/model.py
new file mode 100644
index 0000000000..616b5bbc60
--- /dev/null
+++ b/qa/L0_parameters/model_repository/parameter/1/model.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import triton_python_backend_utils as pb_utils
+import numpy as np
+
+
+class TritonPythonModel:
+
+    @staticmethod
+    def auto_complete_config(auto_complete_model_config):
+        inputs = [{'name': 'INPUT0', 'data_type': 'TYPE_FP32', 'dims': [1]}]
+        outputs = [{'name': 'OUTPUT0', 'data_type': 'TYPE_STRING', 'dims': [1]}]
+
+        config = auto_complete_model_config.as_dict()
+        input_names = []
+        output_names = []
+        for input in config['input']:
+            input_names.append(input['name'])
+        for output in config['output']:
+            output_names.append(output['name'])
+
+        for input in inputs:
+            if input['name'] not in input_names:
+                auto_complete_model_config.add_input(input)
+        for output in outputs:
+            if output['name'] not in output_names:
+                auto_complete_model_config.add_output(output)
+
+        auto_complete_model_config.set_max_batch_size(0)
+        return auto_complete_model_config
+
+    def execute(self, requests):
+        # A simple model that copies the parameters from the request into the
+        # output.
+        responses = []
+        for request in requests:
+            output0 = np.asarray([request.parameters()], dtype=object)
+            output_tensor = pb_utils.Tensor("OUTPUT0", output0)
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[output_tensor])
+            responses.append(inference_response)
+
+        return responses
diff --git a/qa/L0_parameters/parameters_test.py b/qa/L0_parameters/parameters_test.py
new file mode 100644
index 0000000000..7fccaaf973
--- /dev/null
+++ b/qa/L0_parameters/parameters_test.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys + +sys.path.append("../common") + +import numpy as np +import infer_util as iu +import test_util as tu +import tritonclient.http as httpclient +import tritonclient.grpc as grpcclient +import tritonclient.http.aio as asynchttpclient +import tritonclient.grpc.aio as asyncgrpcclient +from tritonclient.utils import InferenceServerException +from unittest import IsolatedAsyncioTestCase +import json +import unittest +import queue +from functools import partial + + +class InferenceParametersTest(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.http = httpclient.InferenceServerClient(url='localhost:8000') + self.async_http = asynchttpclient.InferenceServerClient( + url='localhost:8000') + self.grpc = grpcclient.InferenceServerClient(url='localhost:8001') + self.async_grpc = asyncgrpcclient.InferenceServerClient( + url='localhost:8001') + + self.parameter_list = [] + self.parameter_list.append({'key1': 'value1', 'key2': 'value2'}) + self.parameter_list.append({'key1': 1, 'key2': 2}) + self.parameter_list.append({'key1': True, 'key2': 'value2'}) + self.parameter_list.append({'triton_': True, 'key2': 'value2'}) + + def callback(user_data, result, error): + if error: + user_data.put(error) + else: + user_data.put(result) + + self.grpc_callback = callback + + def create_inputs(self, client_type): + inputs = [] + inputs.append(client_type.InferInput('INPUT0', [1], "FP32")) + + # Initialize the data + inputs[0].set_data_from_numpy(np.asarray([1], dtype=np.float32)) + return inputs + + async def send_request_and_verify(self, + client_type, + client, + is_async=False): + + inputs = self.create_inputs(client_type) + for parameters in self.parameter_list: + # The `triton_` prefix is reserved for Triton usage + should_error = False + if 'triton_' in parameters.keys(): + should_error = True + + if is_async: + if should_error: + with self.assertRaises(InferenceServerException): + result = await client.infer(model_name='parameter', + inputs=inputs, + parameters=parameters) + return + else: + result = await 
client.infer(model_name='parameter', + inputs=inputs, + parameters=parameters) + + else: + if should_error: + with self.assertRaises(InferenceServerException): + result = client.infer(model_name='parameter', + inputs=inputs, + parameters=parameters) + return + else: + result = client.infer(model_name='parameter', + inputs=inputs, + parameters=parameters) + + self.verify_outputs(result, parameters) + + def verify_outputs(self, result, parameters): + result = result.as_numpy('OUTPUT0') + self.assertEqual(json.loads(result[0]), parameters) + + async def test_grpc_parameter(self): + await self.send_request_and_verify(grpcclient, self.grpc) + + async def test_http_parameter(self): + await self.send_request_and_verify(httpclient, self.http) + + async def test_async_http_parameter(self): + await self.send_request_and_verify(asynchttpclient, + self.async_http, + is_async=True) + + async def test_async_grpc_parameter(self): + await self.send_request_and_verify(asyncgrpcclient, + self.async_grpc, + is_async=True) + + def test_http_async_parameter(self): + inputs = self.create_inputs(httpclient) + # Skip the parameter that returns an error + parameter_list = self.parameter_list[:-1] + for parameters in parameter_list: + result = self.http.async_infer(model_name='parameter', + inputs=inputs, + parameters=parameters).get_result() + self.verify_outputs(result, parameters) + + def test_grpc_async_parameter(self): + user_data = queue.Queue() + inputs = self.create_inputs(grpcclient) + # Skip the parameter that returns an error + parameter_list = self.parameter_list[:-1] + for parameters in parameter_list: + self.grpc.async_infer(model_name='parameter', + inputs=inputs, + parameters=parameters, + callback=partial(self.grpc_callback, + user_data)) + result = user_data.get() + self.assertFalse(result is InferenceServerException) + self.verify_outputs(result, parameters) + + def test_grpc_stream_parameter(self): + user_data = queue.Queue() + self.grpc.start_stream(callback=partial(self.grpc_callback, user_data)) + inputs = self.create_inputs(grpcclient) + # Skip the parameter that returns an error + parameter_list = self.parameter_list[:-1] + for parameters in parameter_list: + self.grpc.async_stream_infer(model_name='parameter', + inputs=inputs, + parameters=parameters) + result = user_data.get() + self.assertFalse(result is InferenceServerException) + self.verify_outputs(result, parameters) + self.grpc.stop_stream() + + async def asyncTearDown(self): + self.http.close() + self.grpc.close() + await self.async_grpc.close() + await self.async_http.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/qa/L0_parameters/test.sh b/qa/L0_parameters/test.sh new file mode 100644 index 0000000000..5def250177 --- /dev/null +++ b/qa/L0_parameters/test.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION} +if [ "$#" -ge 1 ]; then + REPO_VERSION=$1 +fi +if [ -z "$REPO_VERSION" ]; then + echo -e "Repository version must be specified" + echo -e "\n***\n*** Test Failed\n***" + exit 1 +fi +if [ ! -z "$TEST_REPO_ARCH" ]; then + REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH} +fi + +CLIENT_LOG="./client.log" +TEST_SCRIPT_PY="parameters_test.py" +EXPECTED_NUM_TESTS="4" + +SERVER=/opt/tritonserver/bin/tritonserver +SERVER_ARGS="--model-repository=model_repository --exit-timeout-secs=120" +SERVER_LOG="./inference_server.log" +source ../common/util.sh + +RET=0 + +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi + +set +e +python3 $TEST_SCRIPT_PY >$CLIENT_LOG 2>&1 +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Failed\n***" + RET=1 +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID + +if [ $RET -eq 0 ]; then + echo -e "\n***\n*** Test Passed\n***" +else + cat $CLIENT_LOG + echo -e "\n***\n*** Test FAILED\n***" +fi + +exit $RET diff --git a/src/grpc_server.cc b/src/grpc_server.cc index 97bbd0628a..9d4539726f 100644 --- a/src/grpc_server.cc +++ b/src/grpc_server.cc @@ -28,6 +28,7 @@ #include #include + #include #include #include @@ -38,6 +39,7 @@ #include #include #include + #include "classification.h" #include "common.h" #include "grpc++/grpc++.h" @@ -3615,6 +3617,38 @@ SetInferenceRequestMetadata( } RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetTimeoutMicroseconds( inference_request, infer_param.int64_param())); + } else if (param.first.rfind("triton_", 0) == 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + ("parameter keys starting with 'triton_' are reserved for Triton " + "usage " + "and should not be specified.")); + } else { + const auto& infer_param = param.second; + if (infer_param.parameter_choice_case() == + inference::InferParameter::ParameterChoiceCase::kInt64Param) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetIntParameter( + inference_request, param.first.c_str(), infer_param.int64_param())); + } else if ( + infer_param.parameter_choice_case() == + inference::InferParameter::ParameterChoiceCase::kBoolParam) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( + inference_request, param.first.c_str(), infer_param.bool_param())); + } else if ( + infer_param.parameter_choice_case() == + inference::InferParameter::ParameterChoiceCase::kStringParam) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( + inference_request, param.first.c_str(), + 
infer_param.string_param().c_str())); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "invalid value type for '" + param.first + + "' parameter, expected " + "int64_param, bool_param, or string_param.") + .c_str()); + } } } diff --git a/src/http_server.cc b/src/http_server.cc index d1af9aa283..d1c2a83cd8 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -32,9 +32,11 @@ #include #include + #include #include #include + #include "classification.h" #define TRITONJSON_STATUSTYPE TRITONSERVER_Error* @@ -2354,82 +2356,100 @@ HTTPAPIServer::EVBufferToInput( AllocPayload::OutputInfo::Kind default_output_kind = AllocPayload::OutputInfo::JSON; - // Set sequence correlation ID and flags if any triton::common::TritonJson::Value params_json; - if (request_json.Find("parameters", ¶ms_json)) { - triton::common::TritonJson::Value seq_json; - if (params_json.Find("sequence_id", &seq_json)) { - // Try to parse sequence_id as uint64_t - uint64_t seq_id; - if (seq_json.AsUInt(&seq_id) != nullptr) { - // On failure try to parse as a string - std::string seq_id; - RETURN_MSG_IF_ERR( - seq_json.AsString(&seq_id), "Unable to parse 'sequence_id'"); - RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetCorrelationIdString( - irequest, seq_id.c_str())); - } else { - RETURN_IF_ERR( - TRITONSERVER_InferenceRequestSetCorrelationId(irequest, seq_id)); - } - } + if (request_json.MemberAsObject("parameters", ¶ms_json) == nullptr) { + std::vector parameters; + RETURN_MSG_IF_ERR( + params_json.Members(¶meters), "failed to get request params."); uint32_t flags = 0; - - { - triton::common::TritonJson::Value start_json; - if (params_json.Find("sequence_start", &start_json)) { + for (auto& parameter : parameters) { + if (parameter == "sequence_id") { + uint64_t seq_id; + // Try to parse sequence_id as uint64_t + if (params_json.MemberAsUInt(parameter.c_str(), &seq_id) != nullptr) { + // On failure try to parse as a string + std::string seq_id; + RETURN_MSG_IF_ERR( + params_json.MemberAsString(parameter.c_str(), &seq_id), + "Unable to parse 'sequence_id'"); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetCorrelationIdString( + irequest, seq_id.c_str())); + } else { + RETURN_IF_ERR( + TRITONSERVER_InferenceRequestSetCorrelationId(irequest, seq_id)); + } + } else if (parameter == "sequence_start") { bool start; RETURN_MSG_IF_ERR( - start_json.AsBool(&start), "Unable to parse 'sequence_start'"); + params_json.MemberAsBool(parameter.c_str(), &start), + "Unable to parse 'sequence_start'"); if (start) { flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START; } - } - - triton::common::TritonJson::Value end_json; - if (params_json.Find("sequence_end", &end_json)) { + } else if (parameter == "sequence_end") { bool end; RETURN_MSG_IF_ERR( - end_json.AsBool(&end), "Unable to parse 'sequence_end'"); + params_json.MemberAsBool(parameter.c_str(), &end), + "Unable to parse 'sequence_end'"); if (end) { flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END; } - } - } - - RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetFlags(irequest, flags)); - - { - triton::common::TritonJson::Value priority_json; - if (params_json.Find("priority", &priority_json)) { + } else if (parameter == "priority") { uint64_t p; RETURN_MSG_IF_ERR( - priority_json.AsUInt(&p), "Unable to parse 'priority'"); + params_json.MemberAsUInt(parameter.c_str(), &p), + "Unable to parse 'priority'"); RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetPriority(irequest, p)); - } - } - - { - triton::common::TritonJson::Value timeout_json; - if 
(params_json.Find("timeout", &timeout_json)) { + } else if (parameter == "timeout") { uint64_t t; - RETURN_MSG_IF_ERR(timeout_json.AsUInt(&t), "Unable to parse 'timeout'"); + RETURN_MSG_IF_ERR( + params_json.MemberAsUInt(parameter.c_str(), &t), + "Unable to parse 'timeout'"); RETURN_IF_ERR( TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(irequest, t)); - } - } - - { - triton::common::TritonJson::Value bdo_json; - if (params_json.Find("binary_data_output", &bdo_json)) { + } else if (parameter == "binary_data_output") { bool bdo; RETURN_MSG_IF_ERR( - bdo_json.AsBool(&bdo), "Unable to parse 'binary_data_output'"); + params_json.MemberAsBool(parameter.c_str(), &bdo), + "Unable to parse 'binary_data_output'"); default_output_kind = (bdo) ? AllocPayload::OutputInfo::BINARY : AllocPayload::OutputInfo::JSON; + } else if (parameter.rfind("triton_", 0) == 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + ("parameter keys starting with 'triton_' are reserved for Triton " + "usage " + "and should not be specified.")); + } else { + std::string string_value; + int64_t int_value; + bool bool_value; + if (params_json.MemberAsString(parameter.c_str(), &string_value) == + nullptr) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( + irequest, parameter.c_str(), string_value.c_str())); + } else if ( + params_json.MemberAsInt(parameter.c_str(), &int_value) == nullptr) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetIntParameter( + irequest, parameter.c_str(), int_value)); + } else if ( + params_json.MemberAsBool(parameter.c_str(), &bool_value) == + nullptr) { + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( + irequest, parameter.c_str(), bool_value)); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + ("parameter '" + parameter + + "' has invalid type. It should be either " + "'int', 'bool', or 'string'.") + .c_str()); + } } } + + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetFlags(irequest, flags)); } // Get the byte-size for each input and from that get the blocks