Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Descriptor Architecture #486

Draft
wants to merge 20 commits into
base: fea-mnmg
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5b9c796
Prototype unified singular descriptor class
jadu-nv Jul 8, 2024
25afe54
Remove mrc/codable/type_traits.hpp and introduce requires
jadu-nv Jul 8, 2024
a952ceb
Add complex object to TransferFullDescriptors test and debug deserial…
jadu-nv Jul 8, 2024
a07642d
Refactor callback logic to use shared_ptr rather than tokens and add …
jadu-nv Jul 9, 2024
8d46d14
Fix deferred messaging pull and refactor registration cache
jadu-nv Jul 12, 2024
5010e10
Clean up registration cache code, use std::optional, increase byte si…
jadu-nv Jul 15, 2024
ff5486d
Add support for device memory RDMA and testing
jadu-nv Jul 18, 2024
c27676d
Add initial benchmarking of descriptors
jadu-nv Jul 22, 2024
1584fe9
Bug fix for device memory sent via eager protocol and add async multi…
jadu-nv Jul 25, 2024
7066fa4
Add more benchmark cases and fix typo
jadu-nv Jul 29, 2024
a547dc8
Add test for back pressure of the DataPlane resources object
jadu-nv Jul 29, 2024
c76a12b
Refactor benchmark to achieve async send
jadu-nv Jul 31, 2024
3003dc6
Add async receive with coroutines, streamline benchmarking, and add c…
jadu-nv Aug 6, 2024
23f4661
Add logic to deregister memory from registration cache
jadu-nv Aug 9, 2024
6228e1a
Add TypedDescriptor for storing the type of descriptor
jadu-nv Aug 10, 2024
3c6e273
Refactor receive coroutine to use ClosableRingBuffer
jadu-nv Aug 12, 2024
ef267bc
Add asynchronous sending and new semaphore coroutines library
jadu-nv Aug 13, 2024
11eb547
Add total time and more accurate benchmarking
jadu-nv Aug 14, 2024
3147b96
Fix descriptor tests
jadu-nv Aug 14, 2024
19451ab
Revert changes made to CMake and add thorough commenting
jadu-nv Aug 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix deferred messaging pull and refactor registration cache
  • Loading branch information
jadu-nv committed Jul 12, 2024
commit 8d46d14ff75ede073db8e3788185ba32a22d968d
17 changes: 9 additions & 8 deletions cpp/mrc/include/mrc/runtime/remote_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class Descriptor2 : public std::enable_shared_from_this<Descriptor2>
memory::buffer serialize(std::shared_ptr<memory::memory_resource> mr);

template <typename T>
[[nodiscard]] const T deserialize() const;
[[nodiscard]] const T deserialize();

static std::shared_ptr<Descriptor2> create(std::any value, data_plane::DataPlaneResources2& data_plane_resources);
static std::shared_ptr<Descriptor2> create(memory::buffer_view view, data_plane::DataPlaneResources2& data_plane_resources);
Expand All @@ -443,6 +443,8 @@ class Descriptor2 : public std::enable_shared_from_this<Descriptor2>

std::any m_value;

std::vector<memory::buffer> m_local_buffers;

std::unique_ptr<codable::DescriptorObjectHandler> m_encoded_object;

data_plane::DataPlaneResources2& m_data_plane_resources;
Expand All @@ -453,7 +455,7 @@ memory::buffer Descriptor2::serialize(std::shared_ptr<memory::memory_resource> m
{
if (!m_encoded_object)
{
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<T>(m_value)));
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<const T&>(m_value)));
this->setup_remote_payloads();
}

Expand All @@ -471,12 +473,11 @@ memory::buffer Descriptor2::serialize(std::shared_ptr<memory::memory_resource> m
}

template <typename T>
[[nodiscard]] const T Descriptor2::deserialize() const
[[nodiscard]] const T Descriptor2::deserialize()
{
return m_value.has_value() ? std::move(std::any_cast<T>(m_value)) :
std::move(mrc::codable::decode2<T>(*m_encoded_object));
T return_value = m_value.has_value() ? std::move(std::any_cast<T>(m_value)) :
std::move(mrc::codable::decode2<T>(*m_encoded_object));
m_value.reset();
return std::move(return_value);
}



} // namespace mrc::runtime
29 changes: 11 additions & 18 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ DataPlaneResources2::DataPlaneResources2()
m_address = m_worker->getAddress();

DVLOG(10) << "initialize the registration cache for this context";
m_registration_cache = std::make_shared<ucx::RegistrationCache2>(m_context);
m_registration_cache3 = std::make_shared<ucx::RegistrationCache3>(m_context);

auto pull_complete_callback = ucxx::AmReceiverCallbackType([this](std::shared_ptr<ucxx::Request> req) {
auto status = req->getStatus();
Expand Down Expand Up @@ -204,6 +204,11 @@ ucx::RegistrationCache2& DataPlaneResources2::registration_cache() const
return *m_registration_cache;
}

ucx::RegistrationCache3& DataPlaneResources2::registration_cache3() const
{
return *m_registration_cache3;
}

std::shared_ptr<ucxx::Endpoint> DataPlaneResources2::create_endpoint(const ucx::WorkerAddress& address,
uint64_t instance_id)
{
Expand Down Expand Up @@ -288,32 +293,20 @@ std::shared_ptr<ucxx::Request> DataPlaneResources2::memory_send_async(std::share
std::shared_ptr<ucxx::Request> DataPlaneResources2::memory_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint,
memory::buffer_view buffer_view,
uintptr_t remote_addr,
const void* packed_rkey_data)
const std::string& serialized_rkey)
{
return this->memory_recv_async(endpoint, buffer_view.data(), buffer_view.bytes(), remote_addr, packed_rkey_data);
return this->memory_recv_async(endpoint, buffer_view.data(), buffer_view.bytes(), remote_addr, serialized_rkey);
}

std::shared_ptr<ucxx::Request> DataPlaneResources2::memory_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint,
void* addr,
std::size_t bytes,
uintptr_t remote_addr,
const void* packed_rkey_data)
const std::string& serialized_rkey)
{
ucp_rkey_h rkey;

// Unpack the key
auto rc = ucp_ep_rkey_unpack(endpoint->getHandle(), packed_rkey_data, &rkey);
CHECK_EQ(rc, UCS_OK);

// Const cast away because UCXX only accepts void*
auto request = endpoint->memGet(addr,
bytes,
remote_addr,
rkey,
false,
[rkey](ucs_status_t status, std::shared_ptr<void> user_data) {
ucp_rkey_destroy(rkey);
});
auto rkey = ucxx::createRemoteKeyFromSerialized(endpoint, serialized_rkey);
auto request = endpoint->memGet(addr, bytes, rkey);

return request;
}
Expand Down
8 changes: 6 additions & 2 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace ucxx {
class Context;
class Endpoint;
class Worker;
class RemoteKey;
class Request;
class Address;
} // namespace ucxx
Expand All @@ -61,6 +62,7 @@ class NetworkResources;
namespace mrc::ucx {
class RegistrationCache;
class RegistrationCache2;
class RegistrationCache3;
class UcxResources;
} // namespace mrc::ucx

Expand Down Expand Up @@ -134,6 +136,7 @@ class DataPlaneResources2
std::string address() const;

ucx::RegistrationCache2& registration_cache() const;
ucx::RegistrationCache3& registration_cache3() const;

std::shared_ptr<ucxx::Endpoint> create_endpoint(const std::string& address, uint64_t instance_id);

Expand Down Expand Up @@ -164,13 +167,13 @@ class DataPlaneResources2
std::shared_ptr<ucxx::Request> memory_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint,
memory::buffer_view buffer_view,
uintptr_t remote_addr,
const void* packed_rkey_data);
const std::string& serialized_rkey);

std::shared_ptr<ucxx::Request> memory_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint,
void* addr,
std::size_t bytes,
uintptr_t remote_addr,
const void* packed_rkey_data);
const std::string& serialized_rkey);

std::shared_ptr<ucxx::Request> tagged_send_async(std::shared_ptr<ucxx::Endpoint> endpoint,
memory::const_buffer_view buffer_view,
Expand Down Expand Up @@ -208,6 +211,7 @@ class DataPlaneResources2
std::shared_ptr<ucxx::Address> m_address;

std::shared_ptr<ucx::RegistrationCache2> m_registration_cache;
std::shared_ptr<ucx::RegistrationCache3> m_registration_cache3;

std::map<std::string, std::shared_ptr<ucxx::Endpoint>> m_endpoints_by_address;
std::map<uint64_t, std::shared_ptr<ucxx::Endpoint>> m_endpoints_by_id;
Expand Down
32 changes: 32 additions & 0 deletions cpp/mrc/src/internal/ucx/registration_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,36 @@ void RegistrationCache2::unregister_memory(ucp_mem_h handle, void* rbuffer)
}
}
}

RegistrationCache3::RegistrationCache3(std::shared_ptr<ucxx::Context> context) : m_context(std::move(context))
{
CHECK(m_context);
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::add_block(void* addr, std::size_t bytes)
{
DCHECK(addr && bytes);
std::lock_guard<decltype(m_mutex)> lock(m_mutex);
m_memory_handle_by_address[addr] = m_context->createMemoryHandle(bytes, addr);
return m_memory_handle_by_address[addr];
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::add_block(uintptr_t addr, std::size_t bytes)
{
return this->add_block(reinterpret_cast<void*>(addr), bytes);
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::lookup(const void* addr) const noexcept
{
std::lock_guard<decltype(m_mutex)> lock(m_mutex);
if (m_memory_handle_by_address.find(addr) != m_memory_handle_by_address.end())
{
return m_memory_handle_by_address.at(addr);
}
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::lookup(uintptr_t addr) const noexcept
{
return this->lookup(reinterpret_cast<const void*>(addr));
}
} // namespace mrc::ucx
63 changes: 63 additions & 0 deletions cpp/mrc/src/internal/ucx/registration_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

namespace ucxx {
class Context;
class MemoryHandle;
}

namespace mrc::ucx {
Expand Down Expand Up @@ -174,4 +175,66 @@ class RegistrationCache2 final
memory::BlockManager<MemoryBlock> m_blocks;
};

/**
* @brief UCX Registration Cache
*
* UCX memory registration object that will both register/deregister memory as well as cache the set of local and remote
* keys for each registration. The cache can be queried for the original memory block by providing any valid address
* contained in the contiguous block.
*/
class RegistrationCache3 final
{
public:
RegistrationCache3(std::shared_ptr<ucxx::Context> context);

/**
* @brief Register a contiguous block of memory starting at addr and spanning `bytes` bytes.
*
* For each block of memory registered with the RegistrationCache, an entry containing the block information is
* storage and can be queried.
*
* @param addr
* @param bytes
*/
std::shared_ptr<ucxx::MemoryHandle> add_block(void* addr, std::size_t bytes);

std::shared_ptr<ucxx::MemoryHandle> add_block(uintptr_t addr, std::size_t bytes);

/**
* @brief Deregister a contiguous block of memory from the ucx context and remove the cache entry
*
* @param addr
* @param bytes
* @return std::size_t
*/
std::size_t drop_block(const void* addr, std::size_t bytes);

std::size_t drop_block(uintptr_t addr, std::size_t bytes);

/**
* @brief Look up the memory registration details for a given address.
*
* This method queries the registration cache to find the UcxMemoryBlock containing the original address and size as
* well as the local and remote keys associated with the memory block.
*
* Any address contained within a registered block can be used to query the UcxMemoryBlock
*
* @param addr
* @return const MemoryBlock&
*/
std::shared_ptr<ucxx::MemoryHandle> lookup(const void* addr) const noexcept;

std::shared_ptr<ucxx::MemoryHandle> lookup(uintptr_t addr) const noexcept;

private:
ucp_mem_h register_memory(const void* address, std::size_t bytes);

std::tuple<ucp_mem_h, void*, std::size_t> register_memory_with_rkey(const void* address, std::size_t bytes);

void unregister_memory(ucp_mem_h handle, void* rbuffer = nullptr);

mutable std::mutex m_mutex;
const std::shared_ptr<ucxx::Context> m_context;
std::map<const void*, std::shared_ptr<ucxx::MemoryHandle>> m_memory_handle_by_address;
};
} // namespace mrc::ucx
48 changes: 21 additions & 27 deletions cpp/mrc/src/public/runtime/remote_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ std::unique_ptr<LocalDescriptor2> LocalDescriptor2::from_remote(std::unique_ptr<
// Allocate the memory needed for this and prevent it from going out-of-scope before request completes
buffers.emplace_back(deferred_remote_msg.bytes(), mr);

// now issue the request
requests.push_back(data_plane_resources.memory_recv_async(ep,
buffers.back(),
deferred_remote_msg.address(),
deferred_remote_msg.remote_key().data()));
// // now issue the request
// requests.push_back(data_plane_resources.memory_recv_async(ep,
// buffers.back(),
// deferred_remote_msg.address(),
// deferred_remote_msg.remote_key().data()));

deferred_msg->set_address(reinterpret_cast<uintptr_t>(buffers.back().data()));
deferred_msg->set_bytes(buffers.back().bytes());
Expand Down Expand Up @@ -348,16 +348,14 @@ std::shared_ptr<Descriptor2> Descriptor2::create(memory::buffer_view view, data_
LOG(FATAL) << "Failed to parse EncodedObjectProto from bytes";
}

auto mr = memory::malloc_memory_resource::instance();

std::vector<std::shared_ptr<ucxx::Request>> requests;
std::vector<memory::buffer> buffers;

// Get the endpoint of the remote descriptor
auto ep = data_plane_resources.find_endpoint(descriptor->proto().instance_id());

// Loop over all remote payloads and convert them to local payloads
for (const auto& remote_payload : descriptor->proto().payloads())
for (auto& remote_payload : *descriptor->proto().mutable_payloads())
{
// If payload is an EagerMessage, we do not need to do any pulling
if (remote_payload.has_eager_msg())
Expand All @@ -366,19 +364,19 @@ std::shared_ptr<Descriptor2> Descriptor2::create(memory::buffer_view view, data_
}

// Get the DeferredMessage of the remote payload
auto deferred_remote_msg = remote_payload.deferred_msg();
auto* deferred_remote_msg = remote_payload.mutable_deferred_msg();

// Allocate the memory needed for this and prevent it from going out-of-scope before request completes
buffers.emplace_back(deferred_remote_msg.bytes(), mr);
auto mr = memory::malloc_memory_resource::instance();
buffers.emplace_back(deferred_remote_msg->bytes(), mr);

// now issue the request
requests.push_back(data_plane_resources.memory_recv_async(ep,
buffers.back(),
deferred_remote_msg.address(),
deferred_remote_msg.remote_key().data()));
deferred_remote_msg->address(),
deferred_remote_msg->remote_key()));

deferred_remote_msg.set_address(reinterpret_cast<uintptr_t>(buffers.back().data()));
deferred_remote_msg.set_bytes(buffers.back().bytes());
deferred_remote_msg->set_address(reinterpret_cast<uintptr_t>(buffers.back().data()));
deferred_remote_msg->set_bytes(buffers.back().bytes());
}

// Now, we need to wait for all requests to be complete
Expand All @@ -395,7 +393,9 @@ std::shared_ptr<Descriptor2> Descriptor2::create(memory::buffer_view view, data_
UCS_MEMORY_TYPE_HOST,
ucxx::AmReceiverCallbackInfo("MRC", 0));

return std::shared_ptr<Descriptor2>(new Descriptor2(std::move(descriptor), data_plane_resources));
auto instance = std::shared_ptr<Descriptor2>(new Descriptor2(std::move(descriptor), data_plane_resources));
instance->m_local_buffers = std::move(buffers);
return instance;
}

void Descriptor2::setup_remote_payloads()
Expand All @@ -417,18 +417,12 @@ void Descriptor2::setup_remote_payloads()

auto* deferred_msg = payload.mutable_deferred_msg();

auto ucx_block = m_data_plane_resources.registration_cache().lookup(deferred_msg->address());

if (!ucx_block.has_value())
{
// Need to register the memory
ucx_block = m_data_plane_resources.registration_cache().add_block(deferred_msg->address(),
deferred_msg->bytes());
}
// Need to register the memory
auto ucx_block = m_data_plane_resources.registration_cache3().add_block(deferred_msg->address(),
deferred_msg->bytes());

deferred_msg->set_memory_block_address(reinterpret_cast<std::uint64_t>(ucx_block->data()));
deferred_msg->set_memory_block_size(ucx_block->bytes());
deferred_msg->set_remote_key(ucx_block->packed_remote_keys());
auto serializedRemoteKey = ucx_block->createRemoteKey()->serialize();
deferred_msg->set_remote_key(serializedRemoteKey);
}
}

Expand Down
Loading