Skip to content

Commit 9fa76f9

Browse files
authored
[UR][L0 v2] Add USM pool tracking in context and fix async pool cleanup (#18802)
This PR mirrors the fixes implemented in #18406 for the L0 v1 adapter. To clean up async pools properly, USM pool handles must be tracked. This tracking also enables the `UR_USM_ALLOC_INFO_POOL` property to be implemented in `urUSMGetMemAllocInfo`.
1 parent 0dbf7ec commit 9fa76f9

File tree

7 files changed

+71
-13
lines changed

7 files changed

+71
-13
lines changed

unified-runtime/source/adapters/level_zero/usm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ ur_result_t urUSMPoolCreate(
507507
*Pool = reinterpret_cast<ur_usm_pool_handle_t>(
508508
new ur_usm_pool_handle_t_(Context, PoolDesc));
509509

510-
std::shared_lock<ur_shared_mutex> ContextLock(Context->Mutex);
510+
std::scoped_lock<ur_shared_mutex> ContextLock(Context->Mutex);
511511
Context->UsmPoolHandles.insert(Context->UsmPoolHandles.cend(), *Pool);
512512

513513
} catch (const UsmAllocationException &Ex) {
@@ -531,7 +531,7 @@ ur_result_t
531531
/// [in] pointer to USM memory pool
532532
urUSMPoolRelease(ur_usm_pool_handle_t Pool) {
533533
if (Pool->RefCount.decrementAndTest()) {
534-
std::shared_lock<ur_shared_mutex> ContextLock(Pool->Context->Mutex);
534+
std::scoped_lock<ur_shared_mutex> ContextLock(Pool->Context->Mutex);
535535
Pool->Context->UsmPoolHandles.remove(Pool);
536536
delete Pool;
537537
}
@@ -610,7 +610,7 @@ ur_result_t UR_APICALL urUSMPoolCreateExp(
610610
*Pool = reinterpret_cast<ur_usm_pool_handle_t>(
611611
new ur_usm_pool_handle_t_(Context, Device, PoolDesc));
612612

613-
std::shared_lock<ur_shared_mutex> ContextLock(Context->Mutex);
613+
std::scoped_lock<ur_shared_mutex> ContextLock(Context->Mutex);
614614
Context->UsmPoolHandles.insert(Context->UsmPoolHandles.cend(), *Pool);
615615

616616
} catch (const UsmAllocationException &Ex) {
@@ -627,7 +627,7 @@ ur_result_t UR_APICALL urUSMPoolCreateExp(
627627
ur_result_t UR_APICALL urUSMPoolDestroyExp(ur_context_handle_t /*Context*/,
628628
ur_device_handle_t /*Device*/,
629629
ur_usm_pool_handle_t Pool) {
630-
std::shared_lock<ur_shared_mutex> ContextLock(Pool->Context->Mutex);
630+
std::scoped_lock<ur_shared_mutex> ContextLock(Pool->Context->Mutex);
631631
Pool->Context->UsmPoolHandles.remove(Pool);
632632
delete Pool;
633633

unified-runtime/source/adapters/level_zero/v2/context.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ ur_usm_pool_handle_t ur_context_handle_t_::getDefaultUSMPool() {
116116

117117
ur_usm_pool_handle_t ur_context_handle_t_::getAsyncPool() { return &asyncPool; }
118118

119+
/// Registers hPool with this context by appending it to usmPoolHandles.
/// The context Mutex is held through std::scoped_lock for the duration of
/// the insertion.
void ur_context_handle_t_::addUsmPool(ur_usm_pool_handle_t hPool) {
120+
std::scoped_lock<ur_shared_mutex> lock(Mutex);
121+
usmPoolHandles.push_back(hPool);
122+
}
123+
124+
/// Unregisters hPool from this context: removes every entry equal to hPool
/// from usmPoolHandles while holding the context Mutex through
/// std::scoped_lock. Has no effect on the list if hPool is not present.
void ur_context_handle_t_::removeUsmPool(ur_usm_pool_handle_t hPool) {
125+
std::scoped_lock<ur_shared_mutex> lock(Mutex);
126+
usmPoolHandles.remove(hPool);
127+
}
128+
119129
const std::vector<ur_device_handle_t> &
120130
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) const {
121131
return p2pAccessDevices[hDevice->Id.value()];

unified-runtime/source/adapters/level_zero/v2/context.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ struct ur_context_handle_t_ : ur_object {
3333
ur_usm_pool_handle_t getDefaultUSMPool();
3434
ur_usm_pool_handle_t getAsyncPool();
3535

36+
void addUsmPool(ur_usm_pool_handle_t hPool);
37+
void removeUsmPool(ur_usm_pool_handle_t hPool);
38+
39+
/// Invokes func on each pool handle stored in usmPoolHandles, in list order,
/// while holding a std::shared_lock on Mutex. Iteration stops as soon as
/// func returns a value that evaluates to false.
/// NOTE(review): func runs with Mutex held; a callback that calls
/// addUsmPool/removeUsmPool on the same context would try to lock Mutex
/// again — confirm no caller does this.
template <typename Func> void forEachUsmPool(Func func) {
40+
std::shared_lock<ur_shared_mutex> lock(Mutex);
41+
for (const auto &hPool : usmPoolHandles) {
42+
if (!func(hPool))
43+
break;
44+
}
45+
}
46+
3647
const std::vector<ur_device_handle_t> &
3748
getP2PDevices(ur_device_handle_t hDevice) const;
3849

@@ -69,4 +80,5 @@ struct ur_context_handle_t_ : ur_object {
6980

7081
ur_usm_pool_handle_t_ defaultUSMPool;
7182
ur_usm_pool_handle_t_ asyncPool;
83+
std::list<ur_usm_pool_handle_t> usmPoolHandles;
7284
};

unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ ur_result_t ur_queue_immediate_in_order_t::queueFinish() {
160160
(commandListLocked->getZeCommandList(), UINT64_MAX));
161161

162162
hContext->getAsyncPool()->cleanupPoolsForQueue(this);
163+
hContext->forEachUsmPool([this](ur_usm_pool_handle_t hPool) {
164+
hPool->cleanupPoolsForQueue(this);
165+
return true;
166+
});
163167

164168
// Free deferred kernels
165169
for (auto &hKernel : submittedKernels) {

unified-runtime/source/adapters/level_zero/v2/usm.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,14 @@ ur_result_t ur_usm_pool_handle_t_::allocate(
229229

230230
*ppRetMem = umfPoolAlignedMalloc(umfPool, size, alignment);
231231
if (*ppRetMem == nullptr) {
232-
auto umfRet = umfPoolGetLastAllocationError(umfPool);
233-
return umf::umf2urResult(umfRet);
232+
if (pool->asyncPool.cleanup()) { // true means that objects were deallocated
233+
// let's try again
234+
*ppRetMem = umfPoolAlignedMalloc(umfPool, size, alignment);
235+
}
236+
if (*ppRetMem == nullptr) {
237+
auto umfRet = umfPoolGetLastAllocationError(umfPool);
238+
return umf::umf2urResult(umfRet);
239+
}
234240
}
235241

236242
return UR_RESULT_SUCCESS;
@@ -246,6 +252,18 @@ ur_result_t ur_usm_pool_handle_t_::free(void *ptr) {
246252
}
247253
}
248254

255+
/// Reports whether any UsmPool visited by poolManager.forEachPool holds the
/// given UMF pool handle. Handles are compared by pointer identity against
/// p->umfPool.get().
/// @param umfPool UMF pool handle to look for.
/// @return true if a matching UsmPool was found, false otherwise.
bool ur_usm_pool_handle_t_::hasPool(const umf_memory_pool_handle_t umfPool) {
256+
bool found = false;
257+
poolManager.forEachPool([&](UsmPool *p) {
258+
if (p->umfPool.get() == umfPool) {
259+
found = true;
260+
return false; // break
261+
}
262+
// No match for this entry; returning true asks forEachPool to continue
// (presumably mirrors the "return false; // break" convention above).
return true;
263+
});
264+
return found;
265+
}
266+
249267
std::optional<std::pair<void *, ur_event_handle_t>>
250268
ur_usm_pool_handle_t_::allocateEnqueued(ur_context_handle_t hContext,
251269
void *hQueue, bool isInOrderQueue,
@@ -281,14 +299,14 @@ ur_usm_pool_handle_t_::allocateEnqueued(ur_context_handle_t hContext,
281299

282300
void ur_usm_pool_handle_t_::cleanupPools() {
283301
poolManager.forEachPool([&](UsmPool *p) {
284-
return p->asyncPool.cleanup();
302+
p->asyncPool.cleanup();
285303
return true;
286304
});
287305
}
288306

289307
void ur_usm_pool_handle_t_::cleanupPoolsForQueue(void *hQueue) {
290308
poolManager.forEachPool([&](UsmPool *p) {
291-
return p->asyncPool.cleanupForQueue(hQueue);
309+
p->asyncPool.cleanupForQueue(hQueue);
292310
return true;
293311
});
294312
}
@@ -303,6 +321,7 @@ ur_result_t urUSMPoolCreate(
303321
/// [out] pointer to USM memory pool
304322
ur_usm_pool_handle_t *hPool) try {
305323
*hPool = new ur_usm_pool_handle_t_(hContext, pPoolDesc);
324+
hContext->addUsmPool(*hPool);
306325
return UR_RESULT_SUCCESS;
307326
} catch (umf_result_t e) {
308327
return umf::umf2urResult(e);
@@ -325,6 +344,7 @@ ur_result_t
325344
/// [in] pointer to USM memory pool
326345
urUSMPoolRelease(ur_usm_pool_handle_t hPool) try {
327346
if (hPool->RefCount.decrementAndTest()) {
347+
hPool->getContextHandle()->removeUsmPool(hPool);
328348
delete hPool;
329349
}
330350
return UR_RESULT_SUCCESS;
@@ -517,13 +537,25 @@ ur_result_t urUSMGetMemAllocInfo(
517537
return ReturnValue(size);
518538
}
519539
case UR_USM_ALLOC_INFO_POOL: {
520-
// TODO
521-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
540+
auto umfPool = umfPoolByPtr(ptr);
541+
if (!umfPool) {
542+
return UR_RESULT_ERROR_INVALID_VALUE;
543+
}
544+
545+
ur_result_t ret = UR_RESULT_ERROR_INVALID_VALUE;
546+
hContext->forEachUsmPool([&](ur_usm_pool_handle_t hPool) {
547+
if (hPool->hasPool(umfPool)) {
548+
ret = ReturnValue(hPool);
549+
return false; // break;
550+
}
551+
return true;
552+
});
553+
return ret;
554+
}
522555
default:
523556
UR_LOG(ERR, "urUSMGetMemAllocInfo: unsupported ParamName");
524557
return UR_RESULT_ERROR_INVALID_VALUE;
525558
}
526-
}
527559
return UR_RESULT_SUCCESS;
528560
} catch (umf_result_t e) {
529561
return umf::umf2urResult(e);

unified-runtime/source/adapters/level_zero/v2/usm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ struct ur_usm_pool_handle_t_ : ur_object {
3838
size_t size, void **ppRetMem);
3939
ur_result_t free(void *ptr);
4040

41+
bool hasPool(const umf_memory_pool_handle_t hPool);
42+
4143
std::optional<std::pair<void *, ur_event_handle_t>>
4244
allocateEnqueued(ur_context_handle_t hContext, void *hQueue,
4345
bool isInOrderQueue, ur_device_handle_t hDevice,

unified-runtime/test/conformance/usm/urUSMGetMemAllocInfo.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ UUR_DEVICE_TEST_SUITE_WITH_PARAM(
2525
uur::deviceTestWithParamPrinter<ur_usm_alloc_info_t>);
2626

2727
TEST_P(urUSMGetMemAllocInfoPoolTest, SuccessPool) {
28-
UUR_KNOWN_FAILURE_ON(uur::LevelZeroV2{});
29-
3028
size_t property_size = 0;
3129
const ur_usm_alloc_info_t property_name = UR_USM_ALLOC_INFO_POOL;
3230

0 commit comments

Comments
 (0)