Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One cudaStream_t instance per raft::handle_t #291

Merged
merged 58 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dc8ce65
checking in with handle changes
divyegala Jul 12, 2021
e9c88df
working handle cpp tests
divyegala Jul 14, 2021
ec5ec5d
working python handle
divyegala Jul 14, 2021
347b702
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jul 14, 2021
8b27ba2
styling changes
divyegala Jul 14, 2021
102ad4e
removing unnecessary TearDown from matrix gtest
divyegala Jul 14, 2021
6ddb1ff
renaming wrong variable name
divyegala Jul 14, 2021
52e775e
better doc for handle constructor according to review
divyegala Jul 15, 2021
54c67d4
review feedback
divyegala Jul 16, 2021
aedfa52
adjusting default handle stream to per thread
divyegala Jul 19, 2021
a502087
adjusting doc
divyegala Jul 19, 2021
8045f16
handle on knn detail API
divyegala Jul 19, 2021
8b2ab71
convenience function on handle to get stream from pool
divyegala Jul 19, 2021
fa320dd
correcting build
divyegala Jul 20, 2021
c25ab19
stream from pool at index
divyegala Jul 20, 2021
1ccc5cc
removing getting stream from pool functionality on handle
divyegala Jul 20, 2021
240fcf6
passing cpp tests
divyegala Sep 23, 2021
522c571
per-thread stream tests passing
divyegala Sep 23, 2021
89a23f6
solving pos argument
divyegala Oct 4, 2021
c24ecc8
merge upstream
divyegala Oct 4, 2021
e8a7856
passing tests
divyegala Oct 4, 2021
0c9871a
fix for failures in CI
divyegala Oct 4, 2021
5ab4f7c
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 13, 2021
830db09
review comments
divyegala Oct 14, 2021
2cf1e51
merging upstream
divyegala Oct 18, 2021
9a20bbf
resolving bad merge
divyegala Oct 18, 2021
7288978
changing sync method from cdef to def
divyegala Oct 22, 2021
ed6e4d8
removing cdef sync from handle pxd
divyegala Oct 28, 2021
9e83e9a
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 28, 2021
865fa7a
trying legacy stream
divyegala Nov 9, 2021
2044fb2
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 9, 2021
8bdbf81
back to default stream per thread
divyegala Nov 16, 2021
fe05b09
merging branch-22.02
divyegala Nov 16, 2021
d243eca
fixing bad merge
divyegala Nov 16, 2021
553453f
merge branch-21.12
divyegala Nov 16, 2021
5287e6e
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 17, 2021
2e60f56
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 17, 2021
480ba37
correcting legacy to per-thread
divyegala Nov 17, 2021
1877061
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
1be9fc7
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
0efbd91
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
ceac531
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
239a887
merging upstream
divyegala Dec 7, 2021
41d0694
merging upstream
divyegala Dec 7, 2021
a89ab29
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 9, 2021
d106d6e
fixing compiler error
divyegala Dec 9, 2021
daadd95
merging upstream
divyegala Dec 9, 2021
7051e39
Reverting fused l2 changes. cuml CI still seems to be broken
cjnolet Dec 10, 2021
6bb7eeb
Fixing style
cjnolet Dec 10, 2021
3322ebe
merging corey's fused l2 knn bug revert
divyegala Dec 10, 2021
cbb0540
fixing macro name
divyegala Dec 10, 2021
ea97177
fixing typo with curly brace
divyegala Dec 10, 2021
8338dcc
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 10, 2021
9659249
Adding no throw macro variants
cjnolet Dec 10, 2021
d12db1c
Fixing typo
cjnolet Dec 10, 2021
6186ead
pulling corey's macro updates
divyegala Dec 10, 2021
5ed4289
merging upstream
divyegala Dec 10, 2021
e97a938
merging upstream
divyegala Dec 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class handle_t {
{
std::lock_guard<std::mutex> _(mutex_);
if (!cublas_initialized_) {
RAFT_CUBLAS_TRY(cublasCreate(&cublas_handle_));
RAFT_CUBLAS_TRY(cublasSetStream(cublas_handle_, stream_view_));
RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_));
RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_handle_, stream_view_));
cublas_initialized_ = true;
}
return cublas_handle_;
Expand All @@ -94,8 +94,8 @@ class handle_t {
{
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_dn_initialized_) {
RAFT_CUSOLVER_TRY(cusolverDnCreate(&cusolver_dn_handle_));
RAFT_CUSOLVER_TRY(cusolverDnSetStream(cusolver_dn_handle_, stream_view_));
RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_));
RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_dn_handle_, stream_view_));
cusolver_dn_initialized_ = true;
}
return cusolver_dn_handle_;
Expand All @@ -105,8 +105,8 @@ class handle_t {
{
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_sp_initialized_) {
RAFT_CUSOLVER_TRY(cusolverSpCreate(&cusolver_sp_handle_));
RAFT_CUSOLVER_TRY(cusolverSpSetStream(cusolver_sp_handle_, stream_view_));
RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_));
RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_sp_handle_, stream_view_));
cusolver_sp_initialized_ = true;
}
return cusolver_sp_handle_;
Expand All @@ -116,8 +116,8 @@ class handle_t {
{
std::lock_guard<std::mutex> _(mutex_);
if (!cusparse_initialized_) {
RAFT_CUSPARSE_TRY(cusparseCreate(&cusparse_handle_));
RAFT_CUSPARSE_TRY(cusparseSetStream(cusparse_handle_, stream_view_));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_handle_, stream_view_));
cusparse_initialized_ = true;
}
return cusparse_handle_;
Expand Down Expand Up @@ -256,7 +256,7 @@ class handle_t {
{
std::lock_guard<std::mutex> _(mutex_);
if (!device_prop_initialized_) {
RAFT_CUDA_TRY(cudaGetDeviceProperties(&prop_, dev_id_));
RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id_));
device_prop_initialized_ = true;
}
return prop_;
Expand Down Expand Up @@ -292,12 +292,15 @@ class handle_t {

void destroy_resources()
{
///@todo: enable *_NO_THROW variants once we have enabled logging
if (cusparse_initialized_) { RAFT_CUSPARSE_TRY(cusparseDestroy(cusparse_handle_)); }
if (cusolver_dn_initialized_) { RAFT_CUSOLVER_TRY(cusolverDnDestroy(cusolver_dn_handle_)); }
if (cusolver_sp_initialized_) { RAFT_CUSOLVER_TRY(cusolverSpDestroy(cusolver_sp_handle_)); }
if (cublas_initialized_) { RAFT_CUBLAS_TRY(cublasDestroy(cublas_handle_)); }
RAFT_CUDA_TRY(cudaEventDestroy(event_));
if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); }
if (cusolver_dn_initialized_) {
RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_));
}
if (cusolver_sp_initialized_) {
RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_sp_handle_));
}
if (cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_handle_)); }
RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_));
}
}; // class handle_t

Expand Down
23 changes: 23 additions & 0 deletions cpp/include/raft/linalg/cublas_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,31 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call)
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUBLAS_TRY_NO_THROW(call) \
do { \
cublasStatus_t const status = call; \
if (CUBLAS_STATUS_SUCCESS != status) { \
printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
raft::linalg::detail::cublas_error_to_string(status)); \
} \
} while (0)

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(call) CUBLAS_TRY(call)
#endif

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK_NO_THROW
#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call)
#endif

namespace raft {
namespace linalg {
Expand Down
20 changes: 20 additions & 0 deletions cpp/include/raft/linalg/cusolver_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,31 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err)
#define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call)
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUSOLVER_TRY_NO_THROW(call) \
do { \
cusolverStatus_t const status = call; \
if (CUSOLVER_STATUS_SUCCESS != status) { \
printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
raft::linalg::detail::cusolver_error_to_string(status)); \
} \
} while (0)

// FIXME: remove after cuml rename
#ifndef CUSOLVER_CHECK
#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call)
#endif

#ifndef CUSOLVER_CHECK_NO_THROW
#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call)
#endif

namespace raft {
namespace linalg {

Expand Down