Skip to content

Commit

Permalink
Fix triton multi threading when using the C++ stage (#1739)
Browse files Browse the repository at this point in the history
- In 24.03, multithreading had to be disabled in the Triton C++ stage due to a race condition.
- This PR changes the client to use fiber-local storage so that multiple fibers can run with different clients at the same time.
- Allows increasing the `pe_count` of the stage beyond 1.


## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Michael Demoret (https://github.com/mdemoret-nv)
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #1739
  • Loading branch information
mdemoret-nv committed Jun 26, 2024
1 parent 6ea6c49 commit 8095a76
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 137 deletions.
8 changes: 7 additions & 1 deletion morpheus/_lib/include/morpheus/stages/triton_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
#include "morpheus/stages/inference_client_stage.hpp"
#include "morpheus/types.hpp"

#include <boost/fiber/fss.hpp>
#include <http_client.h>
#include <mrc/coroutines/task.hpp>

#include <cstdint>
// IWYU pragma: no_include "rxcpp/sources/rx-iterate.hpp"

#include <memory>
#include <mutex>
#include <string>
#include <vector>

Expand Down Expand Up @@ -106,7 +108,11 @@ class MORPHEUS_EXPORT ITritonClient
class MORPHEUS_EXPORT HttpTritonClient : public ITritonClient
{
private:
std::unique_ptr<triton::client::InferenceServerHttpClient> m_client;
std::string m_server_url;
std::mutex m_client_mutex;
boost::fibers::fiber_specific_ptr<triton::client::InferenceServerHttpClient> m_fiber_local_client;

triton::client::InferenceServerHttpClient& get_client();

public:
HttpTritonClient(std::string server_url);
Expand Down
149 changes: 77 additions & 72 deletions morpheus/_lib/messages/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "morpheus/utilities/string_util.hpp"
#include "morpheus/version.hpp"

#include <glog/logging.h> // for COMPACT_GOOGLE_LOG_INFO, LogMessage, VLOG
#include <mrc/edge/edge_connector.hpp>
#include <nlohmann/json.hpp> // for basic_json
#include <pybind11/functional.h> // IWYU pragma: keep
Expand All @@ -53,17 +54,76 @@
#include <pymrc/utils.hpp> // for pymrc::import
#include <rxcpp/rx.hpp>

#include <cstddef> // for size_t
#include <filesystem>
#include <memory>
#include <sstream>
#include <string>
#include <tuple> // IWYU pragma: keep
#include <typeinfo> // for type_info
#include <utility> // for index_sequence, make_index_sequence
#include <vector>
// For some reason IWYU thinks the variant header is needed for tuple, and that the array header is needed for
// tuple_element
// IWYU pragma: no_include <array>
// IWYU pragma: no_include <variant>

namespace morpheus {

namespace fs = std::filesystem;
namespace py = pybind11;

/**
 * @brief Registers the `mrc::edge::EdgeConnector` converter between two shared-pointer message types.
 *
 * @tparam FirstT Pointee type used for the first `EdgeConnector` template argument
 * @tparam SecondT Pointee type used for the second `EdgeConnector` template argument
 */
template <typename FirstT, typename SecondT>
void reg_converter()
{
    using first_ptr_t  = std::shared_ptr<FirstT>;
    using second_ptr_t = std::shared_ptr<SecondT>;

    mrc::edge::EdgeConnector<first_ptr_t, second_ptr_t>::register_converter();
}

template <typename T>
void reg_py_type_helper()
{
// Register the port util
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<T>>();

// Register conversion to and from python
mrc::edge::EdgeConnector<std::shared_ptr<T>, mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder, std::shared_ptr<T>>::register_converter();
}

/**
 * @brief Handles a single (left, right) pair of element types from `TupleT`, selected by a flattened index.
 *
 * With `N = std::tuple_size_v<TupleT>`, the flat index `I` is decoded as `LeftIndex = I / N` and
 * `RightIndex = I % N`. A converter is registered through `reg_converter<left_t, right_t>()` only when `left_t`
 * is a different type from `right_t` and derives from it; every other pair is logged at VLOG level 20 and skipped.
 *
 * @tparam TupleT `std::tuple` listing the candidate types
 * @tparam I Flattened pair index, must be less than `N * N`
 */
template <typename TupleT, std::size_t I>
void do_register_tuple_index()
{
    static constexpr std::size_t TupleSize = std::tuple_size_v<TupleT>;

    // Gives a readable diagnostic instead of a failure deep inside std::tuple_element
    static_assert(I < TupleSize * TupleSize, "Flattened pair index is out of range for TupleT");

    static constexpr std::size_t LeftIndex  = I / TupleSize;
    static constexpr std::size_t RightIndex = I % TupleSize;

    using left_t  = std::tuple_element_t<LeftIndex, TupleT>;
    using right_t = std::tuple_element_t<RightIndex, TupleT>;

    // Only register the derived -> base direction (left_t must be a strict subclass of right_t). The reverse pair
    // is visited under a different index, where this condition is false and it is skipped.
    if constexpr (!std::is_same_v<left_t, right_t> && std::is_base_of_v<right_t, left_t>)
    {
        // Print the registration
        VLOG(20) << "[Type Registration]: Registering: " << typeid(left_t).name() << " -> " << typeid(right_t).name();
        reg_converter<left_t, right_t>();
    }
    else
    {
        VLOG(20) << "[Type Registration]: Skipping: " << typeid(left_t).name() << " -> " << typeid(right_t).name();
    }
}

template <typename TupleT, std::size_t... Is>
void register_tuple_index(std::index_sequence<Is...> /*unused*/)
{
(do_register_tuple_index<TupleT, Is>(), ...);
}

/**
 * @brief Visits every ordered pair of types drawn from `TypesT...` via `register_tuple_index`.
 *
 * The pairs are enumerated with a single flat index sequence of length `sizeof...(TypesT)` squared.
 *
 * @tparam TypesT Types to pair up with one another
 */
template <typename... TypesT>
void register_permutations()
{
    using types_tuple_t = std::tuple<TypesT...>;

    constexpr std::size_t type_count = sizeof...(TypesT);
    constexpr std::size_t pair_count = type_count * type_count;

    register_tuple_index<types_tuple_t>(std::make_index_sequence<pair_count>{});
}

PYBIND11_MODULE(messages, _module)
{
_module.doc() = R"pbdoc(
Expand All @@ -86,80 +146,25 @@ PYBIND11_MODULE(messages, _module)
// Allows python objects to keep DataTable objects alive
py::class_<IDataTable, std::shared_ptr<IDataTable>>(_module, "DataTable");

mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<ControlMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MessageMeta>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiTensorMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiInferenceMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiInferenceFILMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiInferenceNLPMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiResponseMessage>>();
mrc::pymrc::PortBuilderUtil::register_port_util<std::shared_ptr<MultiResponseProbsMessage>>();

// EdgeConnectors for converting between PyObjectHolders and various Message types
mrc::edge::EdgeConnector<std::shared_ptr<morpheus::ControlMessage>,
mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder,
std::shared_ptr<morpheus::ControlMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MessageMeta>, mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder, std::shared_ptr<morpheus::MessageMeta>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiMessage>, mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder, std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceMessage>,
mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder,
std::shared_ptr<morpheus::MultiInferenceMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseMessage>,
mrc::pymrc::PyObjectHolder>::register_converter();
mrc::edge::EdgeConnector<mrc::pymrc::PyObjectHolder,
std::shared_ptr<morpheus::MultiResponseMessage>>::register_converter();
// Add type registrations for all our common types
reg_py_type_helper<ControlMessage>();
reg_py_type_helper<MessageMeta>();
reg_py_type_helper<MultiMessage>();
reg_py_type_helper<MultiTensorMessage>();
reg_py_type_helper<MultiInferenceMessage>();
reg_py_type_helper<MultiInferenceFILMessage>();
reg_py_type_helper<MultiInferenceNLPMessage>();
reg_py_type_helper<MultiResponseMessage>();
reg_py_type_helper<MultiResponseProbsMessage>();

// EdgeConnectors for derived classes of MultiMessage to MultiMessage
mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiTensorMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceMessage>,
std::shared_ptr<morpheus::MultiTensorMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceFILMessage>,
std::shared_ptr<morpheus::MultiInferenceMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceFILMessage>,
std::shared_ptr<morpheus::MultiTensorMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceFILMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceNLPMessage>,
std::shared_ptr<morpheus::MultiInferenceMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceNLPMessage>,
std::shared_ptr<morpheus::MultiTensorMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiInferenceNLPMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseMessage>,
std::shared_ptr<morpheus::MultiTensorMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseProbsMessage>,
std::shared_ptr<morpheus::MultiResponseMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseProbsMessage>,
std::shared_ptr<morpheus::MultiTensorMessage>>::register_converter();

mrc::edge::EdgeConnector<std::shared_ptr<morpheus::MultiResponseProbsMessage>,
std::shared_ptr<morpheus::MultiMessage>>::register_converter();
register_permutations<MultiMessage,
MultiTensorMessage,
MultiInferenceMessage,
MultiInferenceFILMessage,
MultiInferenceNLPMessage,
MultiResponseMessage,
MultiResponseProbsMessage>();

// Tensor Memory classes
py::class_<TensorMemory, std::shared_ptr<TensorMemory>>(_module, "TensorMemory")
Expand Down
Loading

0 comments on commit 8095a76

Please sign in to comment.