Skip to content

Commit 2ad3f93

Browse files
committed
Add locks around cuda driver api and sparse calls
The driver API in CUDA is not thread safe(or we are using it incorrectly). We need to protect these calls with mutexes. The cuSparse library also needs to be protected with these locks.
1 parent 5176e1f commit 2ad3f93

File tree

5 files changed

+22
-7
lines changed

5 files changed

+22
-7
lines changed

src/api/c/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ af_err af_get_active_backend(af_backend *result)
7474
af_err af_init()
7575
{
7676
try {
77-
static std::once_flag flag;
77+
thread_local std::once_flag flag;
7878
std::call_once(flag, []() {
7979
getDeviceInfo();
8080
});

src/backend/cuda/jit.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ static kc_entry_t compileKernel(const char *ker_name, string jit_ker)
514514
reinterpret_cast<void*>(1)
515515
};
516516

517+
std::lock_guard<std::mutex> lock(getDriverApiMutex(getActiveDeviceId()));
517518
CU_CHECK(cuLinkCreate(5, linkOptions, linkOptionValues, &linkState));
518519
CU_CHECK(cuLinkAddData(linkState, CU_JIT_INPUT_PTX, (void*)ptx.get(),
519520
ptx_size, ker_name, 0, NULL, NULL));
@@ -689,6 +690,7 @@ void evalNodes(vector<Param<T> >&outputs, vector<Node *> output_nodes)
689690
args.push_back((void *)&num_odims);
690691
}
691692

693+
std::lock_guard<std::mutex> lock(getDriverApiMutex(getActiveDeviceId()));
692694
CU_CHECK(cuLaunchKernel(ker,
693695
blocks_x,
694696
blocks_y,
@@ -739,7 +741,4 @@ template void evalNodes<intl >(vector<Param<intl > > &out, vector<Node *> no
739741
template void evalNodes<uintl >(vector<Param<uintl > > &out, vector<Node *> node);
740742
template void evalNodes<short >(vector<Param<short > > &out, vector<Node *> node);
741743
template void evalNodes<ushort >(vector<Param<ushort > > &out, vector<Node *> node);
742-
743-
744-
745744
}

src/backend/cuda/platform.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ unsigned getMaxJitSize()
283283
return length;
284284
}
285285

286+
std::mutex& getDriverApiMutex(int device) {
287+
return DeviceManager::getInstance().driver_api_mutex[device];
288+
}
289+
286290
int& tlocalActiveDeviceId()
287291
{
288292
thread_local int activeDeviceId = 0;

src/backend/cuda/platform.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
#include <cuda.h>
1313
#include <cuda_runtime.h>
14-
#include <memory>
15-
#include <vector>
16-
#include <string>
1714
#include <memory.hpp>
1815
#include <GraphicsResourceManager.hpp>
1916
#include <cufft.hpp>
@@ -22,6 +19,11 @@
2219
#include <cusparse.hpp>
2320
#include <common/types.hpp>
2421

22+
#include <memory>
23+
#include <mutex>
24+
#include <string>
25+
#include <vector>
26+
2527
namespace cuda
2628
{
2729
int getBackend();
@@ -41,6 +43,8 @@ void devprop(char* d_name, char* d_platform, char *d_toolkit, char* d_compute);
4143

4244
unsigned getMaxJitSize();
4345

46+
std::mutex& getDriverApiMutex(int device);
47+
4448
int getDeviceCount();
4549

4650
int getActiveDeviceId();
@@ -111,6 +115,8 @@ class DeviceManager
111115
friend GraphicsResourceManager& interopManager();
112116
#endif
113117

118+
friend std::mutex& getDriverApiMutex(int device);
119+
114120
friend std::string getDeviceInfo(int device);
115121

116122
friend std::string getPlatformInfo();
@@ -159,6 +165,7 @@ class DeviceManager
159165

160166
std::unique_ptr<MemoryManagerPinned> pinnedMemManager;
161167

168+
std::mutex driver_api_mutex[MAX_DEVICES];
162169
#if defined(WITH_GRAPHICS)
163170
std::unique_ptr<GraphicsResourceManager> gfxManagers[MAX_DEVICES];
164171
#endif

src/backend/cuda/sparse_blas.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ Array<T> matmul(const common::SparseArray<T> lhs, const Array<T> rhs,
142142

143143
dim4 rStrides = rhs.strides();
144144

145+
// NOTE: The cuSparse library seems to be using the driver API in the
146+
// implementation. This is causing issues with our JIT kernel generation.
147+
// This may be a bug in the cuSparse library.
148+
std::lock_guard<std::mutex> lock(getDriverApiMutex(getActiveDeviceId()));
149+
145150
// Create Sparse Matrix Descriptor
146151
cusparseMatDescr_t descr = 0;
147152
CUSPARSE_CHECK(cusparseCreateMatDescr(&descr));

0 commit comments

Comments
 (0)