Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Descriptor Architecture #486

Draft
wants to merge 20 commits into
base: fea-mnmg
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5b9c796
Prototype unified singular descriptor class
jadu-nv Jul 8, 2024
25afe54
Remove mrc/codable/type_traits.hpp and introduce requires
jadu-nv Jul 8, 2024
a952ceb
Add complex object to TransferFullDescriptors test and debug deserial…
jadu-nv Jul 8, 2024
a07642d
Refactor callback logic to use shared_ptr rather than tokens and add …
jadu-nv Jul 9, 2024
8d46d14
Fix deferred messaging pull and refactor registration cache
jadu-nv Jul 12, 2024
5010e10
Clean up registration cache code, use std::optional, increase byte si…
jadu-nv Jul 15, 2024
ff5486d
Add support for device memory RDMA and testing
jadu-nv Jul 18, 2024
c27676d
Add initial benchmarking of descriptors
jadu-nv Jul 22, 2024
1584fe9
Bug fix for device memory sent via eager protocol and add async multi…
jadu-nv Jul 25, 2024
7066fa4
Add more benchmark cases and fix typo
jadu-nv Jul 29, 2024
a547dc8
Add test for back pressure of the DataPlane resources object
jadu-nv Jul 29, 2024
c76a12b
Refactor benchmark to achieve async send
jadu-nv Jul 31, 2024
3003dc6
Add async receive with coroutines, streamline benchmarking, and add c…
jadu-nv Aug 6, 2024
23f4661
Add logic to deregister memory from registration cache
jadu-nv Aug 9, 2024
6228e1a
Add TypedDescriptor for storing the type of descriptor
jadu-nv Aug 10, 2024
3c6e273
Refactor receive coroutine to use ClosableRingBuffer
jadu-nv Aug 12, 2024
ef267bc
Add asynchronous sending and new semaphore coroutines library
jadu-nv Aug 13, 2024
11eb547
Add total time and more accurate benchmarking
jadu-nv Aug 14, 2024
3147b96
Fix descriptor tests
jadu-nv Aug 14, 2024
19451ab
Revert changes made to CMake and add thorough commenting
jadu-nv Aug 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Clean up registration cache code, use std::optional, increase byte si…
…ze in tests
  • Loading branch information
jadu-nv committed Jul 15, 2024
commit 5010e10f4a3f6741330ea1d512cb30e9612a34d4
5 changes: 3 additions & 2 deletions cpp/mrc/src/internal/ucx/registration_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,17 @@ std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::add_block(uintptr_t addr
return this->add_block(reinterpret_cast<void*>(addr), bytes);
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::lookup(const void* addr) const noexcept
std::optional<std::shared_ptr<ucxx::MemoryHandle>> RegistrationCache3::lookup(const void* addr) const noexcept
{
std::lock_guard<decltype(m_mutex)> lock(m_mutex);
if (m_memory_handle_by_address.find(addr) != m_memory_handle_by_address.end())
{
return m_memory_handle_by_address.at(addr);
}
return std::nullopt;
}

std::shared_ptr<ucxx::MemoryHandle> RegistrationCache3::lookup(uintptr_t addr) const noexcept
std::optional<std::shared_ptr<ucxx::MemoryHandle>> RegistrationCache3::lookup(uintptr_t addr) const noexcept
{
return this->lookup(reinterpret_cast<const void*>(addr));
}
Expand Down
34 changes: 7 additions & 27 deletions cpp/mrc/src/internal/ucx/registration_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,8 @@ class RegistrationCache2 final
/**
* @brief UCX Registration Cache
*
* UCX memory registration object that will both register/deregister memory as well as cache the set of local and remote
* keys for each registration. The cache can be queried for the original memory block by providing any valid address
* contained in the contiguous block.
* UCX memory registration object that will both register/deregister memory. The cache can be queried for the original
* memory block by providing the starting address of the contiguous block.
*/
class RegistrationCache3 final
{
Expand All @@ -200,39 +199,20 @@ class RegistrationCache3 final

std::shared_ptr<ucxx::MemoryHandle> add_block(uintptr_t addr, std::size_t bytes);

/**
* @brief Deregister a contiguous block of memory from the ucx context and remove the cache entry
*
* @param addr
* @param bytes
* @return std::size_t
*/
std::size_t drop_block(const void* addr, std::size_t bytes);

std::size_t drop_block(uintptr_t addr, std::size_t bytes);

/**
* @brief Look up the memory registration details for a given address.
*
* This method queries the registration cache to find the UcxMemoryBlock containing the original address and size as
* well as the local and remote keys associated with the memory block.
*
* Any address contained within a registered block can be used to query the UcxMemoryBlock
* This method queries the registration cache to find the MemoryHanlde containing the original address and size as
* well as the serialized remote keys associated with the memory block.
*
* @param addr
* @return const MemoryBlock&
* @return std::shared_ptr<ucxx::MemoryHandle>
*/
std::shared_ptr<ucxx::MemoryHandle> lookup(const void* addr) const noexcept;
std::optional<std::shared_ptr<ucxx::MemoryHandle>> lookup(const void* addr) const noexcept;

std::shared_ptr<ucxx::MemoryHandle> lookup(uintptr_t addr) const noexcept;
std::optional<std::shared_ptr<ucxx::MemoryHandle>> lookup(uintptr_t addr) const noexcept;

private:
ucp_mem_h register_memory(const void* address, std::size_t bytes);

std::tuple<ucp_mem_h, void*, std::size_t> register_memory_with_rkey(const void* address, std::size_t bytes);

void unregister_memory(ucp_mem_h handle, void* rbuffer = nullptr);

mutable std::mutex m_mutex;
const std::shared_ptr<ucxx::Context> m_context;
std::map<const void*, std::shared_ptr<ucxx::MemoryHandle>> m_memory_handle_by_address;
Expand Down
15 changes: 10 additions & 5 deletions cpp/mrc/src/public/runtime/remote_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,17 @@ void Descriptor2::setup_remote_payloads()

auto* deferred_msg = payload.mutable_deferred_msg();

// Need to register the memory
auto ucx_block = m_data_plane_resources.registration_cache3().add_block(deferred_msg->address(),
deferred_msg->bytes());
auto ucx_block = m_data_plane_resources.registration_cache3().lookup(deferred_msg->address());

auto serializedRemoteKey = ucx_block->createRemoteKey()->serialize();
deferred_msg->set_remote_key(serializedRemoteKey);
if (!ucx_block.has_value())
{
// Need to register the memory
ucx_block = m_data_plane_resources.registration_cache3().add_block(deferred_msg->address(),
deferred_msg->bytes());
}

auto remoteKey = ucx_block.value()->createRemoteKey();
deferred_msg->set_remote_key(remoteKey->serialize());
}
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/mrc/src/tests/test_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ TEST_F(TestNetwork, TransferFullDescriptors)
static_assert(codable::member_decodable<ComplexObject>);
static_assert(codable::member_decodable<TransferObject>);

ComplexObject send_data = {"test", 42, {"test", 42, std::vector<u_int8_t>(8_KiB)}, std::vector<u_int8_t>(8_KiB)};
ComplexObject send_data = {"test", 42, {"test", 42, std::vector<u_int8_t>(64_KiB)}, std::vector<u_int8_t>(8_KiB)};

auto send_data_copy = send_data;

Expand Down Expand Up @@ -706,7 +706,7 @@ TEST_F(TestNetwork, TransferFullDescriptorsBroadcast)
// Create initial data
static_assert(codable::decodable<TransferObject>);

TransferObject send_data = {"test", 42, std::vector<u_int8_t>(8_KiB)};
TransferObject send_data = {"test", 42, std::vector<u_int8_t>(64_KiB)};

auto send_data_copy = send_data;

Expand Down