diff --git a/morpheus/_lib/include/morpheus/stages/triton_inference.hpp b/morpheus/_lib/include/morpheus/stages/triton_inference.hpp
index 1cc8af06af..5024d18a56 100644
--- a/morpheus/_lib/include/morpheus/stages/triton_inference.hpp
+++ b/morpheus/_lib/include/morpheus/stages/triton_inference.hpp
@@ -22,6 +22,7 @@
 #include "morpheus/stages/inference_client_stage.hpp"
 #include "morpheus/types.hpp"
 
+#include <boost/fiber/fss.hpp>
 #include
 #include
@@ -29,6 +30,7 @@
 // IWYU pragma: no_include "rxcpp/sources/rx-iterate.hpp"
 #include
+#include <mutex>
 #include
 #include
@@ -106,7 +108,11 @@ class MORPHEUS_EXPORT ITritonClient
 class MORPHEUS_EXPORT HttpTritonClient : public ITritonClient
 {
   private:
-    std::unique_ptr<triton::client::InferenceServerHttpClient> m_client;
+    std::string m_server_url;
+    std::mutex m_client_mutex;
+    boost::fibers::fiber_specific_ptr<triton::client::InferenceServerHttpClient> m_fiber_local_client;
+
+    triton::client::InferenceServerHttpClient& get_client();
 
   public:
     HttpTritonClient(std::string server_url);
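Note on the new members above: `HttpTritonClient` now keeps only the server URL, a mutex, and a `boost::fibers::fiber_specific_ptr`, so each fiber lazily builds its own `InferenceServerHttpClient` instead of sharing one instance. The sketch below is not Morpheus code; it is a minimal, self-contained illustration of that pattern, with made-up `Connection`/`ConnectionPool` names, assuming Boost.Fiber is available.

```cpp
// Minimal sketch (not Morpheus code): one lazily-created object per fiber,
// with a mutex guarding the shared state it is built from.
#include <boost/fiber/fiber.hpp>
#include <boost/fiber/fss.hpp>

#include <iostream>
#include <mutex>
#include <string>
#include <utility>

struct Connection
{
    std::string url;
};

class ConnectionPool
{
  public:
    explicit ConnectionPool(std::string url) : m_url(std::move(url)) {}

    Connection& get()
    {
        if (m_fiber_local.get() == nullptr)
        {
            // Serialize creation because the shared URL could be rewritten here
            std::unique_lock lock(m_mutex);
            m_fiber_local.reset(new Connection{m_url});
        }
        return *m_fiber_local;
    }

  private:
    std::string m_url;
    std::mutex m_mutex;
    boost::fibers::fiber_specific_ptr<Connection> m_fiber_local;
};

int main()
{
    ConnectionPool pool{"localhost:8000"};

    boost::fibers::fiber f1([&] { std::cout << &pool.get() << "\n"; });
    boost::fibers::fiber f2([&] { std::cout << &pool.get() << "\n"; });

    f1.join();
    f2.join();
}
```

The two fibers print different addresses, confirming each gets its own lazily constructed instance; the mutex only protects the shared state read during construction, and `fiber_specific_ptr` deletes each instance when its fiber exits.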
diff --git a/morpheus/_lib/messages/module.cpp b/morpheus/_lib/messages/module.cpp
index fcaacb2e9d..270e52d0a5 100644
--- a/morpheus/_lib/messages/module.cpp
+++ b/morpheus/_lib/messages/module.cpp
@@ -42,6 +42,7 @@
 #include "morpheus/utilities/string_util.hpp"
 #include "morpheus/version.hpp"
 
+#include <glog/logging.h>  // for COMPACT_GOOGLE_LOG_INFO, LogMessage, VLOG
 #include
 #include <nlohmann/json.hpp>  // for basic_json
 #include  // IWYU pragma: keep
@@ -53,17 +54,76 @@
 #include  // for pymrc::import
 #include
+#include <cstddef>  // for size_t
 #include
 #include
 #include
 #include
+#include <tuple>     // IWYU pragma: keep
+#include <typeinfo>  // for type_info
+#include <utility>   // for index_sequence, make_index_sequence
 #include
+// For some reason IWYU thinks the variant header is needed for tuple, and that the array header is needed for
+// tuple_element
+// IWYU pragma: no_include <variant>
+// IWYU pragma: no_include <array>
 
 namespace morpheus {
 
 namespace fs = std::filesystem;
 namespace py = pybind11;
 
+template <typename InputT, typename OutputT>
+void reg_converter()
+{
+    mrc::edge::EdgeConnector<std::shared_ptr<InputT>, std::shared_ptr<OutputT>>::register_converter();
+}
+
+template <typename MessageT>
+void reg_py_type_helper()
+{
+    // Register the port util
+    mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MessageT>>();
+
+    // Register conversion to and from python
+    mrc::edge::EdgeConnector<std::shared_ptr<MessageT>, mrc::pymrc::PyObjectHolder>::register_converter();
+    mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder, std::shared_ptr<MessageT>>::register_converter();
+}
+
+template <typename TupleT, std::size_t I>
+void do_register_tuple_index()
+{
+    static constexpr std::size_t LeftIndex  = I / std::tuple_size<TupleT>::value;
+    static constexpr std::size_t RightIndex = I % std::tuple_size<TupleT>::value;
+
+    using left_t  = typename std::tuple_element<LeftIndex, TupleT>::type;
+    using right_t = typename std::tuple_element<RightIndex, TupleT>::type;
+
+    // Only register if one of the types is a subclass of the other
+    if constexpr (!std::is_same_v<left_t, right_t> && std::is_base_of_v<right_t, left_t>)
+    {
+        // Print the registration
+        VLOG(20) << "[Type Registration]: Registering: " << typeid(left_t).name() << " -> " << typeid(right_t).name();
+        reg_converter<left_t, right_t>();
+    }
+    else
+    {
+        VLOG(20) << "[Type Registration]: Skipping: " << typeid(left_t).name() << " -> " << typeid(right_t).name();
+    }
+};
+
+template <typename TupleT, std::size_t... Is>
+void register_tuple_index(std::index_sequence<Is...> /*unused*/)
+{
+    (do_register_tuple_index<TupleT, Is>(), ...);
+}
+
+template <typename... TypesT>
+void register_permutations()
+{
+    register_tuple_index<std::tuple<TypesT...>>(std::make_index_sequence<(sizeof...(TypesT)) * (sizeof...(TypesT))>());
+}
+
 PYBIND11_MODULE(messages, _module)
 {
     _module.doc() = R"pbdoc(
@@ -86,80 +146,25 @@ PYBIND11_MODULE(messages, _module)
     // Allows python objects to keep DataTable objects alive
     py::class_<IDataTable, std::shared_ptr<IDataTable>>(_module, "DataTable");
 
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-    mrc::pymrc::PortBuilderUtil::register_port_util>();
-
-    // EdgeConnectors for converting between PyObjectHolders and various Message types
-    mrc::edge::EdgeConnector,
-                             mrc::pymrc::PyObjectHolder>::register_converter();
-    mrc::edge::EdgeConnector>::register_converter();
-
-    mrc::edge::EdgeConnector, mrc::pymrc::PyObjectHolder>::register_converter();
-    mrc::edge::EdgeConnector>::register_converter();
-
-    mrc::edge::EdgeConnector, mrc::pymrc::PyObjectHolder>::register_converter();
-    mrc::edge::EdgeConnector>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             mrc::pymrc::PyObjectHolder>::register_converter();
-    mrc::edge::EdgeConnector>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             mrc::pymrc::PyObjectHolder>::register_converter();
-    mrc::edge::EdgeConnector>::register_converter();
+    // Add type registrations for all our common types
+    reg_py_type_helper<ControlMessage>();
+    reg_py_type_helper<MessageMeta>();
+    reg_py_type_helper<MultiMessage>();
+    reg_py_type_helper<MultiTensorMessage>();
+    reg_py_type_helper<MultiInferenceMessage>();
+    reg_py_type_helper<MultiInferenceFILMessage>();
+    reg_py_type_helper<MultiInferenceNLPMessage>();
+    reg_py_type_helper<MultiResponseMessage>();
+    reg_py_type_helper<MultiResponseProbsMessage>();
 
     // EdgeConnectors for derived classes of MultiMessage to MultiMessage
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
-
-    mrc::edge::EdgeConnector,
-                             std::shared_ptr>::register_converter();
+    register_permutations<MultiMessage,
+                          MultiTensorMessage,
+                          MultiInferenceMessage,
+                          MultiInferenceFILMessage,
+                          MultiInferenceNLPMessage,
+                          MultiResponseMessage,
+                          MultiResponseProbsMessage>();
 
     // Tensor Memory classes
     py::class_<TensorMemory, std::shared_ptr<TensorMemory>>(_module, "TensorMemory")
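The boilerplate removed above is replaced by `register_permutations<...>()`, which expands all N×N ordered pairs of the listed message types at compile time and registers a converter only when the right-hand type is a base of the left-hand one. Below is a minimal standalone sketch of that index-sequence/fold-expression technique; it uses placeholder `Base`/`DerivedA`/`DerivedB` types rather than the Morpheus message classes, and simply prints the pairs it would register.

```cpp
// Illustrative placeholder types, not the Morpheus message classes: expand every
// ordered pair of a type list at compile time and act only on derived -> base pairs.
#include <cstddef>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>

struct Base {};
struct DerivedA : Base {};
struct DerivedB : Base {};

template <typename TupleT, std::size_t I>
void visit_pair()
{
    static constexpr std::size_t N = std::tuple_size<TupleT>::value;

    using left_t  = typename std::tuple_element<I / N, TupleT>::type;
    using right_t = typename std::tuple_element<I % N, TupleT>::type;

    // Only ordered pairs where right_t is a base of left_t are interesting
    if constexpr (!std::is_same_v<left_t, right_t> && std::is_base_of_v<right_t, left_t>)
    {
        std::cout << typeid(left_t).name() << " -> " << typeid(right_t).name() << "\n";
    }
}

template <typename TupleT, std::size_t... Is>
void visit_all(std::index_sequence<Is...> /*unused*/)
{
    (visit_pair<TupleT, Is>(), ...);
}

template <typename... TypesT>
void visit_permutations()
{
    visit_all<std::tuple<TypesT...>>(std::make_index_sequence<sizeof...(TypesT) * sizeof...(TypesT)>());
}

int main()
{
    // Prints the DerivedA -> Base and DerivedB -> Base pairs (type names may be mangled)
    visit_permutations<Base, DerivedA, DerivedB>();
}
```

Compiled as C++17, only the two derived-to-base pairs survive the `if constexpr` filter; every other combination is discarded at compile time, which is how the long list of hand-written `EdgeConnector` registrations collapses into a single call.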
diff --git a/morpheus/_lib/src/stages/triton_inference.cpp b/morpheus/_lib/src/stages/triton_inference.cpp
index a78beb5d11..c704c324a6 100644
--- a/morpheus/_lib/src/stages/triton_inference.cpp
+++ b/morpheus/_lib/src/stages/triton_inference.cpp
@@ -36,9 +36,10 @@
 #include <algorithm>  // for min
 #include
 #include
-#include
+#include <functional>  // for function
 #include
 #include
+#include <mutex>
 #include
 #include <stdexcept>  // for runtime_error, out_of_range
 #include
@@ -136,76 +137,35 @@ struct TritonInferOperation
 
 namespace morpheus {
 
-HttpTritonClient::HttpTritonClient(std::string server_url)
+HttpTritonClient::HttpTritonClient(std::string server_url) : m_server_url(std::move(server_url))
 {
-    std::unique_ptr<triton::client::InferenceServerHttpClient> client;
-
-    CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&client, server_url, false));
-
-    bool is_server_live;
-
-    auto status = client->IsServerLive(&is_server_live);
-
-    if (not status.IsOk())
-    {
-        std::string new_server_url = server_url;
-        if (is_default_grpc_port(new_server_url))
-        {
-            LOG(WARNING) << "Failed to connect to Triton at '" << server_url
-                         << "'. Default gRPC port of (8001) was detected but C++ "
-                            "InferenceClientStage uses HTTP protocol. Retrying with default HTTP port (8000)";
-
-            // We are using the default gRPC port, try the default HTTP
-            std::unique_ptr<triton::client::InferenceServerHttpClient> unique_client;
-
-            CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&unique_client, new_server_url, false));
-
-            client = std::move(unique_client);
-
-            status = client->IsServerLive(&is_server_live);
-        }
-        else if (status.Message().find("Unsupported protocol") != std::string::npos)
-        {
-            throw std::runtime_error(MORPHEUS_CONCAT_STR(
-                "Failed to connect to Triton at '"
-                << server_url
-                << "'. Received 'Unsupported Protocol' error. Are you using the right port? The C++ "
-                   "InferenceClientStage uses Triton's HTTP protocol instead of gRPC. Ensure you have "
-                   "specified the HTTP port (Default 8000)."));
-        }
-
-        if (not status.IsOk())
-            throw std::runtime_error(
-                MORPHEUS_CONCAT_STR("Unable to connect to Triton at '"
                                     << server_url << "'. Check the URL and port and ensure the server is running."));
-    }
-
-    m_client = std::move(client);
+    // Force the creation of the client
+    this->get_client();
 }
 
 triton::client::Error HttpTritonClient::is_server_live(bool* live)
 {
-    return m_client->IsServerLive(live);
+    return this->get_client().IsServerLive(live);
 }
 
 triton::client::Error HttpTritonClient::is_server_ready(bool* ready)
 {
-    return m_client->IsServerReady(ready);
+    return this->get_client().IsServerReady(ready);
 }
 
 triton::client::Error HttpTritonClient::is_model_ready(bool* ready, std::string& model_name)
 {
-    return m_client->IsModelReady(ready, model_name);
+    return this->get_client().IsModelReady(ready, model_name);
 }
 
 triton::client::Error HttpTritonClient::model_config(std::string* model_config, std::string& model_name)
 {
-    return m_client->ModelConfig(model_config, model_name);
+    return this->get_client().ModelConfig(model_config, model_name);
 }
 
 triton::client::Error HttpTritonClient::model_metadata(std::string* model_metadata, std::string& model_name)
 {
-    return m_client->ModelMetadata(model_metadata, model_name);
+    return this->get_client().ModelMetadata(model_metadata, model_name);
 }
 
 triton::client::Error HttpTritonClient::async_infer(triton::client::InferenceServerHttpClient::OnCompleteFn callback,
@@ -240,7 +200,7 @@ triton::client::Error HttpTritonClient::async_infer(triton::client::InferenceSer
 
     triton::client::InferResult* result;
 
-    auto status = m_client->Infer(&result, options, inference_input_ptrs, inference_output_ptrs);
+    auto status = this->get_client().Infer(&result, options, inference_input_ptrs, inference_output_ptrs);
 
     callback(result);
 
@@ -248,13 +208,13 @@ triton::client::Error HttpTritonClient::async_infer(triton::client::InferenceSer
     // TODO(cwharris): either fix tests or make this ENV-flagged, as AsyncInfer gives different results.
 
-    // return m_client->AsyncInfer(
-    //     [callback](triton::client::InferResult* result) {
-    //         callback(result);
-    //     },
-    //     options,
-    //     inference_input_ptrs,
-    //     inference_output_ptrs);
+    // return this->get_client().AsyncInfer(
+    //     [callback](triton::client::InferResult* result) {
+    //         callback(result);
+    //     },
+    //     options,
+    //     inference_input_ptrs,
+    //     inference_output_ptrs);
 }
 
 TritonInferenceClientSession::TritonInferenceClientSession(std::shared_ptr<ITritonClient> client,
@@ -522,4 +482,66 @@ std::unique_ptr<IInferenceClientSession> TritonInferenceClient::create_session()
     return std::make_unique<TritonInferenceClientSession>(m_client, m_model_name, m_force_convert_inputs);
 }
 
+triton::client::InferenceServerHttpClient& HttpTritonClient::get_client()
+{
+    if (m_fiber_local_client.get() == nullptr)
+    {
+        // Block in case we need to change the server_url
+        std::unique_lock lock(m_client_mutex);
+
+        std::unique_ptr<triton::client::InferenceServerHttpClient> client;
+
+        CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&client, m_server_url, false));
+
+        bool is_server_live;
+
+        auto status = client->IsServerLive(&is_server_live);
+
+        if (not status.IsOk())
+        {
+            std::string new_server_url = m_server_url;
+            // We are using the default gRPC port, try the default HTTP
+            if (is_default_grpc_port(new_server_url))
+            {
+                LOG(WARNING) << "Failed to connect to Triton at '" << m_server_url
+                             << "'. Default gRPC port of (8001) was detected but C++ "
+                                "InferenceClientStage uses HTTP protocol. Retrying with default HTTP port (8000)";
+
+                CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&client, new_server_url, false));
+
+                status = client->IsServerLive(&is_server_live);
+
+                // If that worked, update the server URL
+                if (status.IsOk() && is_server_live)
+                {
+                    m_server_url = new_server_url;
+                }
+            }
+            else if (status.Message().find("Unsupported protocol") != std::string::npos)
+            {
+                throw std::runtime_error(MORPHEUS_CONCAT_STR(
+                    "Failed to connect to Triton at '"
+                    << m_server_url
+                    << "'. Received 'Unsupported Protocol' error. Are you using the right port? The C++ "
+                       "InferenceClientStage uses Triton's HTTP protocol instead of gRPC. Ensure you have "
+                       "specified the HTTP port (Default 8000)."));
+            }
+
+            if (!status.IsOk())
+                throw std::runtime_error(MORPHEUS_CONCAT_STR(
+                    "Unable to connect to Triton at '"
+                    << m_server_url << "'. Check the URL and port and ensure the server is running."));
+        }
+
+        if (!is_server_live)
+            throw std::runtime_error(MORPHEUS_CONCAT_STR(
+                "Unable to connect to Triton at '"
+                << m_server_url
+                << "'. Server reported as not live. Check the URL and port and ensure the server is running."));
+
+        m_fiber_local_client.reset(client.release());
+    }
+
+    return *m_fiber_local_client;
+}
 }  // namespace morpheus
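The retry path in `get_client()` keeps the old constructor's behavior: when the configured URL looks like Triton's default gRPC port, it retries on the default HTTP port and, on success, rewrites the stored server URL. `is_default_grpc_port()` is defined elsewhere in `triton_inference.cpp` and is not part of this diff, so the snippet below is only a hypothetical illustration of the kind of ":8001" to ":8000" rewrite the retry appears to rely on; the helper name and signature here are invented for the example.

```cpp
// Hypothetical sketch only: the real is_default_grpc_port() is defined elsewhere in
// triton_inference.cpp and is not shown in this diff. This illustrates the assumed
// ":8001" -> ":8000" rewrite that the retry path above depends on.
#include <iostream>
#include <string>

bool looks_like_default_grpc_port(std::string& url)
{
    const std::string grpc_suffix{":8001"};

    if (url.size() < grpc_suffix.size() ||
        url.compare(url.size() - grpc_suffix.size(), grpc_suffix.size(), grpc_suffix) != 0)
    {
        return false;
    }

    // Swap in Triton's default HTTP port so the caller can retry the connection
    url.replace(url.size() - grpc_suffix.size(), grpc_suffix.size(), ":8000");
    return true;
}

int main()
{
    std::string url{"localhost:8001"};

    if (looks_like_default_grpc_port(url))
    {
        std::cout << "retrying on " << url << "\n";  // prints "retrying on localhost:8000"
    }
}
```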
diff --git a/morpheus/stages/inference/triton_inference_stage.py b/morpheus/stages/inference/triton_inference_stage.py
index c46cdcab48..26420c2f51 100644
--- a/morpheus/stages/inference/triton_inference_stage.py
+++ b/morpheus/stages/inference/triton_inference_stage.py
@@ -804,9 +804,4 @@ def _get_cpp_inference_node(self, builder: mrc.Builder) -> mrc.SegmentObject:
     def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
         node = super()._build_single(builder, input_node)
 
-        # ensure that the C++ impl only uses a single progress engine
-        if (self._build_cpp_node()):
-            node.launch_options.pe_count = 1
-            node.launch_options.engines_per_pe = 1
-
         return node
diff --git a/tests/test_sid.py b/tests/test_sid.py
index b36903fd82..d7190bce7a 100755
--- a/tests/test_sid.py
+++ b/tests/test_sid.py
@@ -78,7 +78,6 @@ def _run_minibert_pipeline(
     config.pipeline_batch_size = 1024
     config.feature_length = FEATURE_LENGTH
     config.edge_buffer_size = 128
-    config.num_threads = 1
 
     val_file_name = os.path.join(TEST_DIRS.validation_data_dir, 'sid-validation-data.csv')
     vocab_file_name = os.path.join(TEST_DIRS.data_dir, 'bert-base-uncased-hash.txt')