Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Descriptor Architecture #486

Draft
wants to merge 20 commits into
base: fea-mnmg
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5b9c796
Prototype unified singular descriptor class
jadu-nv Jul 8, 2024
25afe54
Remove mrc/codable/type_traits.hpp and introduce requires
jadu-nv Jul 8, 2024
a952ceb
Add complex object to TransferFullDescriptors test and debug deserial…
jadu-nv Jul 8, 2024
a07642d
Refactor callback logic to use shared_ptr rather than tokens and add …
jadu-nv Jul 9, 2024
8d46d14
Fix deferred messaging pull and refactor registration cache
jadu-nv Jul 12, 2024
5010e10
Clean up registration cache code, use std::optional, increase byte si…
jadu-nv Jul 15, 2024
ff5486d
Add support for device memory RDMA and testing
jadu-nv Jul 18, 2024
c27676d
Add initial benchmarking of descriptors
jadu-nv Jul 22, 2024
1584fe9
Bug fix for device memory sent via eager protocol and add async multi…
jadu-nv Jul 25, 2024
7066fa4
Add more benchmark cases and fix typo
jadu-nv Jul 29, 2024
a547dc8
Add test for back pressure of the DataPlane resources object
jadu-nv Jul 29, 2024
c76a12b
Refactor benchmark to achieve async send
jadu-nv Jul 31, 2024
3003dc6
Add async receive with coroutines, streamline benchmarking, and add c…
jadu-nv Aug 6, 2024
23f4661
Add logic to deregister memory from registration cache
jadu-nv Aug 9, 2024
6228e1a
Add TypedDescriptor for storing the type of descriptor
jadu-nv Aug 10, 2024
3c6e273
Refactor receive coroutine to use ClosableRingBuffer
jadu-nv Aug 12, 2024
ef267bc
Add asynchronous sending and new semaphore coroutines library
jadu-nv Aug 13, 2024
11eb547
Add total time and more accurate benchmarking
jadu-nv Aug 14, 2024
3147b96
Fix descriptor tests
jadu-nv Aug 14, 2024
19451ab
Revert changes made to CMake and add thorough commenting
jadu-nv Aug 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor callback logic to use shared_ptr rather than tokens and add …
…broadcasting test
  • Loading branch information
jadu-nv committed Jul 9, 2024
commit a07642de23b2a0f33c22e11209982e16db1bd753
27 changes: 7 additions & 20 deletions cpp/mrc/include/mrc/runtime/remote_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,6 @@ class Descriptor2 : public std::enable_shared_from_this<Descriptor2>
template <typename T>
memory::buffer serialize(std::shared_ptr<memory::memory_resource> mr);

template <typename T>
memory::buffer_view serialize(memory::buffer_view buffer);

template <typename T>
[[nodiscard]] const T deserialize() const;

Expand All @@ -442,6 +439,7 @@ class Descriptor2 : public std::enable_shared_from_this<Descriptor2>
m_encoded_object(std::move(encoded_object)), m_data_plane_resources(data_plane_resources) {}

void setup_remote_payloads();
void register_remote_descriptor();

std::any m_value;

Expand All @@ -453,27 +451,16 @@ class Descriptor2 : public std::enable_shared_from_this<Descriptor2>
template <typename T>
memory::buffer Descriptor2::serialize(std::shared_ptr<memory::memory_resource> mr)
{
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<T>(m_value)));

setup_remote_payloads();

// Allocate enough bytes to hold the encoded object
auto buffer = memory::buffer(m_encoded_object->proto().ByteSizeLong(), mr);

if (!m_encoded_object->proto().SerializeToArray(buffer.data(), buffer.bytes()))
if (!m_encoded_object)
{
LOG(FATAL) << "Failed to serialize EncodedObjectProto to bytes";
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<T>(m_value)));
this->setup_remote_payloads();
}

return buffer;
}

template <typename T>
memory::buffer_view Descriptor2::serialize(memory::buffer_view buffer)
{
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<T>(m_value)));
this->register_remote_descriptor();

setup_remote_payloads();
// Allocate enough bytes to hold the encoded object
auto buffer = memory::buffer(m_encoded_object->proto().ByteSizeLong(), mr);

if (!m_encoded_object->proto().SerializeToArray(buffer.data(), buffer.bytes()))
{
Expand Down
54 changes: 32 additions & 22 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,22 @@ DataPlaneResources2::DataPlaneResources2()
DVLOG(10) << "initialize the registration cache for this context";
m_registration_cache = std::make_shared<ucx::RegistrationCache2>(m_context);

auto decrement_callback = ucxx::AmReceiverCallbackType([this](std::shared_ptr<ucxx::Request> req) {
auto pull_complete_callback = ucxx::AmReceiverCallbackType([this](std::shared_ptr<ucxx::Request> req) {
auto status = req->getStatus();
if (status != UCS_OK)
{
LOG(ERROR) << "Error calling decrement_callback, request failed with status " << status << "("
LOG(ERROR) << "Error calling pull_complete_callback, request failed with status " << status << "("
<< ucs_status_string(status) << ")";
}

auto* dec_message = reinterpret_cast<remote_descriptor::RemoteDescriptorDecrementMessage*>(
auto* message = reinterpret_cast<remote_descriptor::DescriptorPullCompletionMessage*>(
req->getRecvBuffer()->data());

decrement_tokens(dec_message);
complete_remote_pull(message);
});
m_worker->registerAmReceiverCallback(
ucxx::AmReceiverCallbackInfo(ucxx::AmReceiverCallbackOwnerType("MRC"), ucxx::AmReceiverCallbackIdType(0)),
decrement_callback);
pull_complete_callback);

// flush any work that needs to be done by the workers
this->flush();
Expand Down Expand Up @@ -386,44 +386,54 @@ uint64_t DataPlaneResources2::get_next_object_id()
return m_next_object_id++;
}

uint64_t DataPlaneResources2::register_remote_decriptor(
std::shared_ptr<runtime::Descriptor2> remote_descriptor)
uint64_t DataPlaneResources2::register_remote_decriptor(std::shared_ptr<runtime::Descriptor2> descriptor)
{
auto object_id = get_next_object_id();
remote_descriptor->encoded_object().set_object_id(object_id);
// If the descriptor has an object_id > 0, the descriptor has already been registered and should not be re-registered
auto object_id = descriptor->encoded_object().object_id();
if (object_id > 0)
{
m_descriptor_by_id[object_id].push_back(descriptor);
return object_id;
}

object_id = get_next_object_id();
descriptor->encoded_object().set_object_id(object_id);
{
std::unique_lock lock(m_remote_descriptors_mutex);
m_remote_descriptors_cv.wait(lock, [this] {
return m_remote_descriptor_by_id.size() < m_max_remote_descriptors;
return m_descriptor_by_id.size() < m_max_remote_descriptors;
});
m_remote_descriptor_by_id[object_id] = remote_descriptor;
m_descriptor_by_id[object_id].push_back(descriptor);
}
return object_id;
}

uint64_t DataPlaneResources2::registered_remote_descriptor_count()
{
return m_remote_descriptor_by_id.size();
return m_descriptor_by_id.size();
}

uint64_t DataPlaneResources2::registered_remote_descriptor_token_count(uint64_t object_id)
uint64_t DataPlaneResources2::registered_remote_descriptor_ptr_count(uint64_t object_id)
{
return m_remote_descriptor_by_id.at(object_id)->encoded_object().tokens();
return m_descriptor_by_id.at(object_id).size();
}

void DataPlaneResources2::decrement_tokens(remote_descriptor::RemoteDescriptorDecrementMessage* dec_message)
void DataPlaneResources2::complete_remote_pull(remote_descriptor::DescriptorPullCompletionMessage* message)
{
if (dec_message->tokens > 0)
// If the mapping between object_id to descriptor shared ptrs exists, then there exists >= 1 shared ptrs
if (m_descriptor_by_id.find(message->object_id) != m_descriptor_by_id.end())
{
auto remote_descriptor = m_remote_descriptor_by_id[dec_message->object_id];
auto tokens = remote_descriptor->encoded_object().tokens();
tokens -= dec_message->tokens;
remote_descriptor->encoded_object().set_tokens(tokens);
if (tokens == 0)
// Once we've completed pulling of a descriptor, we remove a descriptor shared ptr from the vector
// When the vector becomes empty, there will be no more shared ptrs pointing to the descriptor object,
// it will be destructed accordingly.
// We should also remove that mapping as the object_id corresponding to that mapping will not be reused.
auto& descriptors = m_descriptor_by_id[message->object_id];
descriptors.pop_back();
if (descriptors.size() == 0)
{
{
std::unique_lock lock(m_remote_descriptors_mutex);
m_remote_descriptor_by_id.erase(dec_message->object_id);
m_descriptor_by_id.erase(message->object_id);
}
m_remote_descriptors_cv.notify_one();
}
Expand Down
11 changes: 6 additions & 5 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ class DataPlaneResources2
ucs_memory_type_t mem_type);
std::shared_ptr<ucxx::Request> am_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint);

uint64_t register_remote_decriptor(std::shared_ptr<runtime::Descriptor2> remote_descriptor);
uint64_t register_remote_decriptor(std::shared_ptr<runtime::Descriptor2> descriptor);
uint64_t registered_remote_descriptor_count();
uint64_t registered_remote_descriptor_token_count(uint64_t object_id);
uint64_t registered_remote_descriptor_ptr_count(uint64_t object_id);

private:
std::optional<uint64_t> m_instance_id; // Global ID used to identify this instance
Expand All @@ -212,22 +212,23 @@ class DataPlaneResources2
std::map<std::string, std::shared_ptr<ucxx::Endpoint>> m_endpoints_by_address;
std::map<uint64_t, std::shared_ptr<ucxx::Endpoint>> m_endpoints_by_id;

std::atomic<uint64_t> m_next_object_id{0};
// An object_id of 0 (default protobuf int field value) signifies an unregistered descriptor
std::atomic<uint64_t> m_next_object_id{1};

// std::shared_ptr<node::Queue<std::unique_ptr<runtime::ValueDescriptor>>> m_outbound_descriptors;
// std::map<InstanceID, std::weak_ptr<node::Queue<std::unique_ptr<runtime::ValueDescriptor>>>>
// m_inbound_port_channels;

uint64_t get_next_object_id();

void decrement_tokens(remote_descriptor::RemoteDescriptorDecrementMessage* dec_message);
void complete_remote_pull(remote_descriptor::DescriptorPullCompletionMessage* message);

uint64_t m_max_remote_descriptors{std::numeric_limits<uint64_t>::max()};
boost::fibers::mutex m_remote_descriptors_mutex{};
boost::fibers::condition_variable m_remote_descriptors_cv{};

protected:
std::map<uint64_t, std::shared_ptr<runtime::Descriptor2>> m_remote_descriptor_by_id;
std::map<uint64_t, std::vector<std::shared_ptr<runtime::Descriptor2>>> m_descriptor_by_id;
};

} // namespace mrc::data_plane
5 changes: 5 additions & 0 deletions cpp/mrc/src/internal/remote_descriptor/messages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@

namespace mrc::remote_descriptor {

struct DescriptorPullCompletionMessage
{
std::uint64_t object_id;
};

struct RemoteDescriptorDecrementMessage
{
std::uint64_t object_id;
Expand Down
9 changes: 5 additions & 4 deletions cpp/mrc/src/public/runtime/remote_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,14 +385,13 @@ std::shared_ptr<Descriptor2> Descriptor2::create(memory::buffer_view view, data_
data_plane_resources.wait_requests(requests);

// For the remote descriptor message, send decrement to the remote resources
remote_descriptor::RemoteDescriptorDecrementMessage dec_message;
remote_descriptor::DescriptorPullCompletionMessage dec_message;
dec_message.object_id = descriptor->proto().object_id();
dec_message.tokens = descriptor->proto().tokens();

// TODO(Peter): Define `ucxx::AmReceiverCallbackInfo` at central place, must be known by all MRC processes.
// Send a decrement message using custom AM receiver callback
auto decrement_request = ep->amSend(&dec_message,
sizeof(remote_descriptor::RemoteDescriptorDecrementMessage),
sizeof(remote_descriptor::DescriptorPullCompletionMessage),
UCS_MEMORY_TYPE_HOST,
ucxx::AmReceiverCallbackInfo("MRC", 0));

Expand All @@ -405,7 +404,6 @@ void Descriptor2::setup_remote_payloads()

// Transfer the info object
remote_object.set_instance_id(m_data_plane_resources.get_instance_id());
remote_object.set_tokens(std::numeric_limits<uint64_t>::max());

// Loop over all local payloads and convert them to remote payloads

Expand All @@ -432,7 +430,10 @@ void Descriptor2::setup_remote_payloads()
deferred_msg->set_memory_block_size(ucx_block->bytes());
deferred_msg->set_remote_key(ucx_block->packed_remote_keys());
}
}

void Descriptor2::register_remote_descriptor()
{
m_data_plane_resources.register_remote_decriptor(shared_from_this());
}
} // namespace mrc::runtime
Loading