@@ -57,25 +57,22 @@ static auto ThrowOnCudaError = [](CUresult res, int lineNum = -1) {
5757
5858class CudaResMgr {
5959 CudaResMgr () {
60+ lock_guard<mutex> lock (gContextsMutex );
61+ lock_guard<mutex> lock (gStreamsMutex );
62+
6063 ThrowOnCudaError (cuInit (0 ), __LINE__);
6164
6265 int nGpu;
6366 ThrowOnCudaError (cuDeviceGetCount (&nGpu), __LINE__);
64- {
65- lock_guard<mutex> lock (gContextsMutex );
66- for (int i = 0 ; i < nGpu; i++) {
67- CUcontext cuContext = nullptr ;
6867
69- g_Contexts.push_back (cuContext);
70- }
71- }
72- {
73- lock_guard<mutex> lock (gStreamsMutex );
74- for (int i = 0 ; i < nGpu; i++) {
75- CUstream cuStream = nullptr ;
7668
77- g_Streams.push_back (cuStream);
78- }
69+ for (int i = 0 ; i < nGpu; i++) {
70+ CUdevice cuDevice = 0 ;
71+ CUcontext cuContext = nullptr ;
72+ g_Contexts.push_back (make_pair (cuDevice,cuContext));
73+
74+ CUstream cuStream = nullptr ;
75+ g_Streams.push_back (cuStream);
7976 }
8077 return ;
8178 }
@@ -90,21 +87,23 @@ class CudaResMgr {
9087 if (idx >= GetNumGpus ()) {
9188 return nullptr ;
9289 }
90+
9391 lock_guard<mutex> lock (gContextsMutex );
9492 auto &ctx = g_Contexts[idx];
95- if (!ctx) {
93+ if (!ctx. second ) {
9694 CUdevice cuDevice = 0 ;
9795 ThrowOnCudaError (cuDeviceGet (&cuDevice, idx), __LINE__);
98- ThrowOnCudaError (cuCtxCreate (&ctx, 0 , cuDevice), __LINE__);
96+ ThrowOnCudaError (cuDevicePrimaryCtxRetain (&ctx. second , cuDevice), __LINE__);
9997 }
10098
101- return g_Contexts[idx];
99+ return g_Contexts[idx]. second ;
102100 }
103101
104102 CUstream GetStream (size_t idx) {
105103 if (idx >= GetNumGpus ()) {
106104 return nullptr ;
107105 }
106+
108107 lock_guard<mutex> lock (gStreamsMutex );
109108 auto &str = g_Streams[idx];
110109 if (!str) {
@@ -131,15 +130,15 @@ class CudaResMgr {
131130 }
132131 g_Streams.clear ();
133132 }
133+
134134 {
135135 lock_guard<mutex> lock (gContextsMutex );
136- for (auto &cuContext : g_Contexts) {
137- if (cuContext ) {
138- ThrowOnCudaError (cuCtxDestroy (cuContext ), __LINE__);
136+ for (int i= 0 ;i< g_Contexts. size ();i++ ) {
137+ if (g_Contexts[i]. second ) {
138+ ThrowOnCudaError (cuDevicePrimaryCtxRelease (g_Contexts[i]. first ), __LINE__);
139139 }
140140 }
141141 g_Contexts.clear ();
142-
143142 }
144143 } catch (runtime_error &e) {
145144 cerr << e.what () << endl;
@@ -154,7 +153,7 @@ class CudaResMgr {
154153
155154 static size_t GetNumGpus () { return Instance ().g_Contexts .size (); }
156155
157- vector<CUcontext> g_Contexts;
156+ vector<pair<CUdevice, CUcontext> > g_Contexts;
158157 vector<CUstream> g_Streams;
159158 mutex gContextsMutex ;
160159 mutex gStreamsMutex ;
0 commit comments