Skip to content

Commit fdb497d

Browse files
opal/cuda: Handle stream-ordered allocations and assign primary device context
Signed-off-by: Akshay Venkatesh <akvenkatesh@nvidia.com> (cherry picked from commit 5328616)
1 parent da2c8fd commit fdb497d

File tree

2 files changed

+108
-9
lines changed

2 files changed

+108
-9
lines changed

3rd-party/prrte

Submodule prrte updated 144 files

opal/mca/accelerator/cuda/accelerator_cuda.c

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,96 @@ static int accelerator_cuda_get_device_id(CUcontext mem_ctx) {
159159
return dev_id;
160160
}
161161

162+
static int accelerator_cuda_check_mpool(CUdeviceptr dbuf, CUmemorytype *mem_type,
163+
int *dev_id)
164+
{
165+
#if OPAL_CUDA_VMM_SUPPORT
166+
static int device_count = -1;
167+
static int mpool_supported = -1;
168+
CUresult result;
169+
CUmemoryPool mpool;
170+
CUmemAccess_flags flags;
171+
CUmemLocation location;
172+
173+
if (mpool_supported <= 0) {
174+
if (mpool_supported == -1) {
175+
if (device_count == -1) {
176+
result = cuDeviceGetCount(&device_count);
177+
if (result != CUDA_SUCCESS || (0 == device_count)) {
178+
mpool_supported = 0; /* never check again */
179+
device_count = 0;
180+
return 0;
181+
}
182+
}
183+
184+
/* assume uniformity of devices */
185+
result = cuDeviceGetAttribute(&mpool_supported,
186+
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, 0);
187+
if (result != CUDA_SUCCESS) {
188+
mpool_supported = 0;
189+
}
190+
}
191+
if (0 == mpool_supported) {
192+
return 0;
193+
}
194+
}
195+
196+
result = cuPointerGetAttribute(&mpool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE,
197+
dbuf);
198+
if (CUDA_SUCCESS != result) {
199+
return 0;
200+
}
201+
202+
/* check if device has access */
203+
for (int i = 0; i < device_count; i++) {
204+
location.type = CU_MEM_LOCATION_TYPE_DEVICE;
205+
location.id = i;
206+
result = cuMemPoolGetAccess(&flags, mpool, &location);
207+
if ((CUDA_SUCCESS == result) &&
208+
(CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) {
209+
*mem_type = CU_MEMORYTYPE_DEVICE;
210+
*dev_id = i;
211+
return 1;
212+
}
213+
}
214+
215+
/* host must have access as device access possibility is exhausted */
216+
*mem_type = CU_MEMORYTYPE_HOST;
217+
*dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
218+
return 0;
219+
#endif
220+
221+
return 0;
222+
}
223+
224+
static int accelerator_cuda_get_primary_context(CUdevice dev_id, CUcontext *pctx)
225+
{
226+
CUresult result;
227+
unsigned int flags;
228+
int active;
229+
230+
result = cuDevicePrimaryCtxGetState(dev_id, &flags, &active);
231+
if (CUDA_SUCCESS != result) {
232+
return OPAL_ERROR;
233+
}
234+
235+
if (active) {
236+
result = cuDevicePrimaryCtxRetain(pctx, dev_id);
237+
return OPAL_SUCCESS;
238+
}
239+
240+
return OPAL_ERROR;
241+
}
242+
162243
static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags)
163244
{
164245
CUresult result;
165246
int is_vmm = 0;
247+
int is_mpool_ptr = 0;
166248
int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
249+
int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID;
167250
CUmemorytype vmm_mem_type = 0;
251+
CUmemorytype mpool_mem_type = 0;
168252
CUmemorytype mem_type = 0;
169253
CUdeviceptr dbuf = (CUdeviceptr) addr;
170254
CUcontext ctx = NULL, mem_ctx = NULL;
@@ -177,6 +261,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
177261
*flags = 0;
178262

179263
is_vmm = accelerator_cuda_check_vmm(dbuf, &vmm_mem_type, &vmm_dev_id);
264+
is_mpool_ptr = accelerator_cuda_check_mpool(dbuf, &mpool_mem_type, &mpool_dev_id);
180265

181266
#if OPAL_CUDA_GET_ATTRIBUTES
182267
uint32_t is_managed = 0;
@@ -210,6 +295,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
210295
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
211296
mem_type = CU_MEMORYTYPE_DEVICE;
212297
*dev_id = vmm_dev_id;
298+
} else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
299+
mem_type = CU_MEMORYTYPE_DEVICE;
300+
*dev_id = mpool_dev_id;
213301
} else {
214302
/* Host memory, nothing to do here */
215303
return 0;
@@ -220,6 +308,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
220308
} else {
221309
if (is_vmm) {
222310
*dev_id = vmm_dev_id;
311+
} else if (is_mpool_ptr) {
312+
*dev_id = mpool_dev_id;
223313
} else {
224314
/* query the device from the context */
225315
*dev_id = accelerator_cuda_get_device_id(mem_ctx);
@@ -238,13 +328,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
238328
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
239329
mem_type = CU_MEMORYTYPE_DEVICE;
240330
*dev_id = vmm_dev_id;
331+
} else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
332+
mem_type = CU_MEMORYTYPE_DEVICE;
333+
*dev_id = mpool_dev_id;
241334
} else {
242335
/* Host memory, nothing to do here */
243336
return 0;
244337
}
245338
} else {
246339
if (is_vmm) {
247340
*dev_id = vmm_dev_id;
341+
} else if (is_mpool_ptr) {
342+
*dev_id = mpool_dev_id;
248343
} else {
249344
result = cuPointerGetAttribute(&mem_ctx,
250345
CU_POINTER_ATTRIBUTE_CONTEXT, dbuf);
@@ -278,14 +373,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
278373
return OPAL_ERROR;
279374
}
280375
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
281-
if (is_vmm) {
282-
/* This function is expected to set context if pointer is device
283-
* accessible but VMM allocations have NULL context associated
284-
* which cannot be set against the calling thread */
285-
opal_output(0,
286-
"CUDA: unable to set context with the given pointer"
287-
"ptr=%p aborting...", addr);
288-
return OPAL_ERROR;
376+
if (is_vmm || is_mpool_ptr) {
377+
if (OPAL_SUCCESS ==
378+
accelerator_cuda_get_primary_context(
379+
is_vmm ? vmm_dev_id : mpool_dev_id, &mem_ctx)) {
380+
/* As VMM/mempool allocations have no context associated
381+
* with them, check if device primary context can be set */
382+
} else {
383+
opal_output(0,
384+
"CUDA: unable to set ctx with the given pointer"
385+
"ptr=%p aborting...", addr);
386+
return OPAL_ERROR;
387+
}
289388
}
290389

291390
result = cuCtxSetCurrent(mem_ctx);

0 commit comments

Comments
 (0)