Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One cudaStream_t instance per raft::handle_t #291

Merged
merged 58 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dc8ce65
checking in with handle changes
divyegala Jul 12, 2021
e9c88df
working handle cpp tests
divyegala Jul 14, 2021
ec5ec5d
working python handle
divyegala Jul 14, 2021
347b702
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jul 14, 2021
8b27ba2
styling changes
divyegala Jul 14, 2021
102ad4e
removing unnecessary TearDown from matrix gtest
divyegala Jul 14, 2021
6ddb1ff
renaming wrong variable name
divyegala Jul 14, 2021
52e775e
better doc for handle constructor according to review
divyegala Jul 15, 2021
54c67d4
review feedback
divyegala Jul 16, 2021
aedfa52
adjusting default handle stream to per thread
divyegala Jul 19, 2021
a502087
adjusting doc
divyegala Jul 19, 2021
8045f16
handle on knn detail API
divyegala Jul 19, 2021
8b2ab71
convenience function on handle to get stream from pool
divyegala Jul 19, 2021
fa320dd
correcting build
divyegala Jul 20, 2021
c25ab19
stream from pool at index
divyegala Jul 20, 2021
1ccc5cc
removing getting stream from pool functionality on handle
divyegala Jul 20, 2021
240fcf6
passing cpp tests
divyegala Sep 23, 2021
522c571
per-thread stream tests passing
divyegala Sep 23, 2021
89a23f6
solving pos argument
divyegala Oct 4, 2021
c24ecc8
merge upstream
divyegala Oct 4, 2021
e8a7856
passing tests
divyegala Oct 4, 2021
0c9871a
fix for failures in CI
divyegala Oct 4, 2021
5ab4f7c
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 13, 2021
830db09
review comments
divyegala Oct 14, 2021
2cf1e51
merging upstream
divyegala Oct 18, 2021
9a20bbf
resolving bad merge
divyegala Oct 18, 2021
7288978
changing sync method from cdef to def
divyegala Oct 22, 2021
ed6e4d8
removing cdef sync from handle pxd
divyegala Oct 28, 2021
9e83e9a
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 28, 2021
865fa7a
trying legacy stream
divyegala Nov 9, 2021
2044fb2
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 9, 2021
8bdbf81
back to default stream per thread
divyegala Nov 16, 2021
fe05b09
merging branch-22.02
divyegala Nov 16, 2021
d243eca
fixing bad merge
divyegala Nov 16, 2021
553453f
merge branch-21.12
divyegala Nov 16, 2021
5287e6e
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 17, 2021
2e60f56
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 17, 2021
480ba37
correcting legacy to per-thread
divyegala Nov 17, 2021
1877061
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
1be9fc7
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
0efbd91
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
ceac531
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
239a887
merging upstream
divyegala Dec 7, 2021
41d0694
merging upstream
divyegala Dec 7, 2021
a89ab29
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 9, 2021
d106d6e
fixing compiler error
divyegala Dec 9, 2021
daadd95
merging upstream
divyegala Dec 9, 2021
7051e39
Reverting fused l2 changes. cuml CI still seems to be broken
cjnolet Dec 10, 2021
6bb7eeb
Fixing style
cjnolet Dec 10, 2021
3322ebe
merging corey's fused l2 knn bug revert
divyegala Dec 10, 2021
cbb0540
fixing macro name
divyegala Dec 10, 2021
ea97177
fixing typo with curly brace
divyegala Dec 10, 2021
8338dcc
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 10, 2021
9659249
Adding no throw macro variants
cjnolet Dec 10, 2021
d12db1c
Fixing typo
cjnolet Dec 10, 2021
6186ead
pulling corey's macro updates
divyegala Dec 10, 2021
5ed4289
merging upstream
divyegala Dec 10, 2021
e97a938
merging upstream
divyegala Dec 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
review feedback
  • Loading branch information
divyegala committed Jul 16, 2021
commit 54c67d48606e6507d477c0df0d5202b809abdb1e
13 changes: 12 additions & 1 deletion cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cublas_initialized_) {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
CUBLAS_CHECK(cublasSetStream(cublas_handle, stream_view_));
CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_view_));
cublas_initialized_ = true;
}
return cublas_handle_;
Expand Down Expand Up @@ -163,6 +163,17 @@ class handle_t {
}
}

/**
* @brief synchronize subset of stream pool
*
* @param[in] stream_indices the indices of the streams in the stream pool to synchronize
*/
void sync_stream_pool(const std::vector<std::size_t> stream_indices) const {
for (const auto& stream_index : stream_indices) {
stream_pool_.get_stream(stream_index).synchronize();
}
}

/**
* @brief ask stream pool to wait on last event in main stream
*/
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
float *out_d_ptr = out_D + (i * k * n);
int64_t *out_i_ptr = out_I + (i * k * n);

cudaStream_t stream =
n_internal_streams > 0 ? internalStreams.get_stream() : userStream;
cudaStream_t stream = n_internal_streams > 0
? internalStreams.get_stream().value()
: userStream;

switch (metric) {
case raft::distance::DistanceType::Haversine:
Expand Down
5 changes: 3 additions & 2 deletions cpp/test/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ TEST(Raft, HandleDefault) {

TEST(Raft, Handle) {
// test stream pool creation
rmm::cuda_stream_pool stream_pool{4};
constexpr std::size_t n_streams = 4;
rmm::cuda_stream_pool stream_pool{n_streams};
handle_t h(rmm::cuda_stream_default, stream_pool);
ASSERT_EQ(4, h.get_stream_pool().get_pool_size());
ASSERT_EQ(n_streams, h.get_stream_pool().get_pool_size());

// test non default stream handle
cudaStream_t stream;
Expand Down