Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable shallow copy of handle_t's resources with different workspace_resource #1165

Merged
merged 5 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding copy constructor to handle/device_resources that enables a dif…
…ferent workspace_resource to be used.
  • Loading branch information
cjnolet committed Jan 24, 2023
commit 32af7f0acdd7701986e1ab7d0fa2640dca466641
5 changes: 1 addition & 4 deletions cpp/include/raft/core/resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ class resources {
* @brief Shallow copy of underlying resources instance.
* Note that this does not create any new resources.
*/
resources(const resources&)
: factories_(resources.factories_.copy()), resources_(resources.resources_.copy())
{
}
resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {}
resources& operator=(const resources&) = delete;
resources(resources&&) = delete;
resources& operator=(resources&&) = delete;
Expand Down
31 changes: 31 additions & 0 deletions cpp/test/core/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,4 +268,35 @@ TEST(Raft, WorkspaceResource)
delete pool_mr;
}

TEST(Raft, WorkspaceResourceCopy)
{
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(10);

handle_t handle(rmm::cuda_stream_per_thread, stream_pool);

auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource());

handle_t copied_handle(handle, pool_mr);

// Assert shallow copied state
ASSERT_EQ(handle.get_stream().value(), copied_handle.get_stream().value());
ASSERT_EQ(handle.get_stream_pool_size(), copied_handle.get_stream_pool_size());

// Sanity check to make sure non-corresponding streams are not equal
ASSERT_NE(handle.get_stream_pool().get_stream(0).value(),
copied_handle.get_stream_pool().get_stream(1).value());

for (size_t i = 0; i < handle.get_stream_pool_size(); ++i) {
ASSERT_EQ(handle.get_stream_pool().get_stream(i).value(),
copied_handle.get_stream_pool().get_stream(i).value());
}

// Assert the workspace_resources are what we expect
ASSERT_TRUE(dynamic_cast<const rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>*>(
handle.get_workspace_resource()) == nullptr);

ASSERT_TRUE(dynamic_cast<const rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>*>(
copied_handle.get_workspace_resource()) != nullptr);
}

} // namespace raft