@@ -159,12 +159,96 @@ static int accelerator_cuda_get_device_id(CUcontext mem_ctx) {
     return dev_id;
 }
 
+static int accelerator_cuda_check_mpool(CUdeviceptr dbuf, CUmemorytype *mem_type,
+                                        int *dev_id)
+{
+#if OPAL_CUDA_VMM_SUPPORT
+    static int device_count = -1;
+    static int mpool_supported = -1;
+    CUresult result;
+    CUmemoryPool mpool;
+    CUmemAccess_flags flags;
+    CUmemLocation location;
+
+    if (mpool_supported <= 0) {
+        if (mpool_supported == -1) {
+            if (device_count == -1) {
+                result = cuDeviceGetCount(&device_count);
+                if (result != CUDA_SUCCESS || (0 == device_count)) {
+                    mpool_supported = 0; /* never check again */
+                    device_count = 0;
+                    return 0;
+                }
+            }
+
+            /* assume uniformity of devices */
+            result = cuDeviceGetAttribute(&mpool_supported,
+                                          CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, 0);
+            if (result != CUDA_SUCCESS) {
+                mpool_supported = 0;
+            }
+        }
+        if (0 == mpool_supported) {
+            return 0;
+        }
+    }
+
+    result = cuPointerGetAttribute(&mpool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE,
+                                   dbuf);
+    if (CUDA_SUCCESS != result) {
+        return 0;
+    }
+
+    /* check if device has access */
+    for (int i = 0; i < device_count; i++) {
+        location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        location.id = i;
+        result = cuMemPoolGetAccess(&flags, mpool, &location);
+        if ((CUDA_SUCCESS == result) &&
+            (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) {
+            *mem_type = CU_MEMORYTYPE_DEVICE;
+            *dev_id = i;
+            return 1;
+        }
+    }
+
+    /* host must have access as device access possibility is exhausted */
+    *mem_type = CU_MEMORYTYPE_HOST;
+    *dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
+    return 0;
+#endif
+
+    return 0;
+}
+
+static int accelerator_cuda_get_primary_context(CUdevice dev_id, CUcontext *pctx)
+{
+    CUresult result;
+    unsigned int flags;
+    int active;
+
+    result = cuDevicePrimaryCtxGetState(dev_id, &flags, &active);
+    if (CUDA_SUCCESS != result) {
+        return OPAL_ERROR;
+    }
+
+    if (active) {
+        result = cuDevicePrimaryCtxRetain(pctx, dev_id);
+        return OPAL_SUCCESS;
+    }
+
+    return OPAL_ERROR;
+}
+
 static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags)
 {
     CUresult result;
     int is_vmm = 0;
+    int is_mpool_ptr = 0;
     int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
+    int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
     CUmemorytype vmm_mem_type = 0;
+    CUmemorytype mpool_mem_type = 0;
     CUmemorytype mem_type = 0;
     CUdeviceptr dbuf = (CUdeviceptr) addr;
     CUcontext ctx = NULL, mem_ctx = NULL;
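The new accelerator_cuda_check_mpool() helper rests on two driver queries: CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE reveals whether a pointer was allocated from a stream-ordered memory pool at all, and cuMemPoolGetAccess() reports which devices have read/write access to that pool. The standalone sketch below is illustrative only (the helper name is made up, and CUDA 11.3 or newer plus a context current on the calling thread are assumed); it shows the same classification pattern outside the OPAL component:

/* Illustrative sketch, not part of the patch: classify a pointer that may come
 * from a CUDA stream-ordered memory pool.  Assumes CUDA >= 11.3, cuInit() done,
 * and a context already current on the calling thread. */
#include <cuda.h>

static int pointer_backed_by_mempool(CUdeviceptr ptr, int device_count, int *owner_dev)
{
    CUmemoryPool pool;

    /* Pointers not obtained from cuMemAllocAsync()/cuMemAllocFromPoolAsync() have
     * no mempool handle, so this query fails and we report "not a pool pointer". */
    if (CUDA_SUCCESS != cuPointerGetAttribute(&pool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, ptr)) {
        return 0;
    }

    /* Walk the devices and report the first one with read/write access to the
     * pool, the same probe accelerator_cuda_check_mpool() performs above. */
    for (int i = 0; i < device_count; ++i) {
        CUmemAccess_flags flags;
        CUmemLocation loc = { .type = CU_MEM_LOCATION_TYPE_DEVICE, .id = i };
        if (CUDA_SUCCESS == cuMemPoolGetAccess(&flags, pool, &loc)
            && CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags) {
            *owner_dev = i;
            return 1;
        }
    }
    return 0;
}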
@@ -177,6 +261,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
     *flags = 0;
 
     is_vmm = accelerator_cuda_check_vmm(dbuf, &vmm_mem_type, &vmm_dev_id);
+    is_mpool_ptr = accelerator_cuda_check_mpool(dbuf, &mpool_mem_type, &mpool_dev_id);
 
 #if OPAL_CUDA_GET_ATTRIBUTES
     uint32_t is_managed = 0;
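The branches that follow lean on the driver's pointer-attribute queries; with OPAL_CUDA_GET_ATTRIBUTES the memory type, owning context, and managed flag come back from a single cuPointerGetAttributes() call. Below is a minimal sketch of that batched query, assuming CUDA 9.2+ for CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL; the function name is illustrative, not OPAL code:

/* Illustrative sketch, not part of the patch: batch several pointer attributes
 * into one driver call, as the OPAL_CUDA_GET_ATTRIBUTES path does. */
#include <cuda.h>
#include <stdint.h>
#include <stdio.h>

static void classify_pointer(CUdeviceptr dbuf)
{
    CUmemorytype mem_type = 0;
    CUcontext mem_ctx = NULL;
    uint32_t is_managed = 0;
    int dev_ordinal = -1;

    CUpointer_attribute attrs[4] = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
                                    CU_POINTER_ATTRIBUTE_CONTEXT,
                                    CU_POINTER_ATTRIBUTE_IS_MANAGED,
                                    CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL};
    void *data[4] = {&mem_type, &mem_ctx, &is_managed, &dev_ordinal};

    /* Unlike cuPointerGetAttribute(), this call does not fail on a plain host
     * pointer: unrecognized pointers get default values and CUDA_SUCCESS. */
    if (CUDA_SUCCESS != cuPointerGetAttributes(4, attrs, data, dbuf)) {
        return;
    }

    if (CU_MEMORYTYPE_DEVICE == mem_type || is_managed) {
        printf("device-accessible memory, ordinal %d, ctx %p\n",
               dev_ordinal, (void *) mem_ctx);
    } else {
        printf("host memory\n");
    }
}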
@@ -210,6 +295,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
             if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
                 mem_type = CU_MEMORYTYPE_DEVICE;
                 *dev_id = vmm_dev_id;
+            } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
+                mem_type = CU_MEMORYTYPE_DEVICE;
+                *dev_id = mpool_dev_id;
             } else {
                 /* Host memory, nothing to do here */
                 return 0;
@@ -220,6 +308,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
         } else {
             if (is_vmm) {
                 *dev_id = vmm_dev_id;
+            } else if (is_mpool_ptr) {
+                *dev_id = mpool_dev_id;
             } else {
                 /* query the device from the context */
                 *dev_id = accelerator_cuda_get_device_id(mem_ctx);
@@ -238,13 +328,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
         if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
             mem_type = CU_MEMORYTYPE_DEVICE;
             *dev_id = vmm_dev_id;
+        } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
+            mem_type = CU_MEMORYTYPE_DEVICE;
+            *dev_id = mpool_dev_id;
         } else {
             /* Host memory, nothing to do here */
             return 0;
         }
     } else {
         if (is_vmm) {
             *dev_id = vmm_dev_id;
+        } else if (is_mpool_ptr) {
+            *dev_id = mpool_dev_id;
         } else {
             result = cuPointerGetAttribute(&mem_ctx,
                                            CU_POINTER_ATTRIBUTE_CONTEXT, dbuf);
@@ -278,14 +373,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
         return OPAL_ERROR;
     }
 #endif /* OPAL_CUDA_GET_ATTRIBUTES */
-    if (is_vmm) {
-        /* This function is expected to set context if pointer is device
-         * accessible but VMM allocations have NULL context associated
-         * which cannot be set against the calling thread */
-        opal_output(0,
-                    "CUDA: unable to set context with the given pointer"
-                    "ptr=%p aborting...", addr);
-        return OPAL_ERROR;
+    if (is_vmm || is_mpool_ptr) {
+        if (OPAL_SUCCESS ==
+            accelerator_cuda_get_primary_context(
+                is_vmm ? vmm_dev_id : mpool_dev_id, &mem_ctx)) {
+            /* As VMM/mempool allocations have no context associated
+             * with them, check if device primary context can be set */
+        } else {
+            opal_output(0,
+                        "CUDA: unable to set ctx with the given pointer"
+                        "ptr=%p aborting...", addr);
+            return OPAL_ERROR;
+        }
     }
 
     result = cuCtxSetCurrent(mem_ctx);
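The replacement error path works because VMM and mempool allocations carry no CUcontext of their own; accelerator_cuda_get_primary_context() retains the owning device's primary context, but only when that context is already active, so cuCtxSetCurrent() has something to bind. The sketch below shows that retain/set/release pairing in isolation, under assumptions of my own (hypothetical helper name, error handling trimmed); it is not OPAL code:

/* Illustrative sketch, not part of the patch: retain a device's primary context
 * only if it is already active, make it current, and release it when done. */
#include <cuda.h>

static int bind_primary_context(CUdevice dev)
{
    unsigned int flags;
    int active = 0;

    if (CUDA_SUCCESS != cuDevicePrimaryCtxGetState(dev, &flags, &active) || !active) {
        return -1; /* nothing to retain: primary context not initialized yet */
    }

    CUcontext ctx;
    if (CUDA_SUCCESS != cuDevicePrimaryCtxRetain(&ctx, dev)) {
        return -1;
    }
    if (CUDA_SUCCESS != cuCtxSetCurrent(ctx)) {
        cuDevicePrimaryCtxRelease(dev); /* drop the reference we just took */
        return -1;
    }
    return 0;
}

/* A caller of this pattern must eventually call cuDevicePrimaryCtxRelease(dev)
 * once per successful bind, or the primary context's refcount never drops. */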