Skip to content

[SYCL] Remove OwnZeMemHandle from USMAllocator #7853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4189,8 +4189,7 @@ pi_result piMemRetain(pi_mem Mem) {
// If indirect access tracking is not enabled then this functions just performs
// zeMemFree. If indirect access tracking is enabled then reference counting is
// performed.
static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
bool OwnZeMemHandle = true) {
static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr) {
pi_platform Plt = Context->getPlatform();
std::unique_lock<pi_shared_mutex> ContextsLock(Plt->ContextsMutex,
std::defer_lock);
Expand All @@ -4210,8 +4209,7 @@ static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
Context->MemAllocs.erase(It);
}

if (OwnZeMemHandle)
ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));
ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));

if (IndirectAccessTrackingEnabled)
PI_CALL(ContextReleaseHelper(Context));
Expand All @@ -4220,7 +4218,7 @@ static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
}

static pi_result USMFreeHelper(pi_context Context, void *Ptr,
bool OwnZeMemHandle);
bool OwnZeMemHandle = true);

pi_result piMemRelease(pi_mem Mem) {
PI_ASSERT(Mem, PI_ERROR_INVALID_MEM_OBJECT);
Expand Down Expand Up @@ -8085,10 +8083,8 @@ static pi_result USMHostAllocImpl(void **ResultPtr, pi_context Context,
return PI_SUCCESS;
}

static pi_result USMFreeImpl(pi_context Context, void *Ptr,
bool OwnZeMemHandle) {
if (OwnZeMemHandle)
ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));
static pi_result USMFreeImpl(pi_context Context, void *Ptr) {
ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));
return PI_SUCCESS;
}

Expand Down Expand Up @@ -8147,8 +8143,8 @@ void *USMMemoryAllocBase::allocate(size_t Size, size_t Alignment) {
return Ptr;
}

void USMMemoryAllocBase::deallocate(void *Ptr, bool OwnZeMemHandle) {
auto Res = USMFreeImpl(Context, Ptr, OwnZeMemHandle);
void USMMemoryAllocBase::deallocate(void *Ptr) {
auto Res = USMFreeImpl(Context, Ptr);
if (Res != PI_SUCCESS) {
throw UsmAllocationException(Res);
}
Expand Down Expand Up @@ -8396,8 +8392,13 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
Context->MemAllocs.erase(It);
}

if (!OwnZeMemHandle) {
// Memory should not be freed
return PI_SUCCESS;
}

if (!UseUSMAllocator) {
pi_result Res = USMFreeImpl(Context, Ptr, OwnZeMemHandle);
pi_result Res = USMFreeImpl(Context, Ptr);
if (IndirectAccessTrackingEnabled)
PI_CALL(ContextReleaseHelper(Context));
return Res;
Expand All @@ -8416,7 +8417,7 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
// If memory type is host release from host pool
if (ZeMemoryAllocationProperties.type == ZE_MEMORY_TYPE_HOST) {
try {
Context->HostMemAllocContext->deallocate(Ptr, OwnZeMemHandle);
Context->HostMemAllocContext->deallocate(Ptr);
} catch (const UsmAllocationException &Ex) {
return Ex.getError();
} catch (...) {
Expand Down Expand Up @@ -8444,16 +8445,16 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
PI_ASSERT(Device, PI_ERROR_INVALID_DEVICE);

auto DeallocationHelper =
[Context, Device, Ptr,
OwnZeMemHandle](std::unordered_map<ze_device_handle_t, USMAllocContext>
&AllocContextMap) {
[Context, Device,
Ptr](std::unordered_map<ze_device_handle_t, USMAllocContext>
&AllocContextMap) {
try {
auto It = AllocContextMap.find(Device->ZeDevice);
if (It == AllocContextMap.end())
return PI_ERROR_INVALID_VALUE;

// The right context is found, deallocate the pointer
It->second.deallocate(Ptr, OwnZeMemHandle);
It->second.deallocate(Ptr);
} catch (const UsmAllocationException &Ex) {
return Ex.getError();
}
Expand All @@ -8479,7 +8480,7 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
}
}

pi_result Res = USMFreeImpl(Context, Ptr, OwnZeMemHandle);
pi_result Res = USMFreeImpl(Context, Ptr);
if (SharedReadOnlyAllocsIterator != Context->SharedReadOnlyAllocs.end()) {
Context->SharedReadOnlyAllocs.erase(SharedReadOnlyAllocsIterator);
}
Expand All @@ -8494,7 +8495,7 @@ pi_result piextUSMFree(pi_context Context, void *Ptr) {
std::scoped_lock<pi_shared_mutex> Lock(
IndirectAccessTrackingEnabled ? Plt->ContextsMutex : Context->Mutex);

return USMFreeHelper(Context, Ptr, true /* OwnZeMemHandle */);
return USMFreeHelper(Context, Ptr);
}

pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
Expand Down Expand Up @@ -9410,11 +9411,11 @@ pi_result _pi_buffer::free() {
std::scoped_lock<pi_shared_mutex> Lock(
IndirectAccessTrackingEnabled ? Plt->ContextsMutex : Context->Mutex);

PI_CALL(USMFreeHelper(Context, ZeHandle, true));
PI_CALL(USMFreeHelper(Context, ZeHandle));
break;
}
case allocation_t::free_native:
PI_CALL(ZeMemFreeHelper(Context, ZeHandle, true));
PI_CALL(ZeMemFreeHelper(Context, ZeHandle));
break;
case allocation_t::unimport:
ZeUSMImport.doZeUSMRelease(Context->getPlatform()->ZeDriver, ZeHandle);
Expand Down
2 changes: 1 addition & 1 deletion sycl/plugins/level_zero/pi_level_zero.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class USMMemoryAllocBase : public SystemMemory {
: Context{Ctx}, Device{Dev} {}
void *allocate(size_t Size) override final;
void *allocate(size_t Size, size_t Alignment) override final;
void deallocate(void *Ptr, bool OwnZeMemHandle) override final;
void deallocate(void *Ptr) override final;
};

// Allocation routines for shared memory type
Expand Down
15 changes: 7 additions & 8 deletions sycl/plugins/unified_runtime/ur/usm_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class USMAllocContext::USMAllocImpl {

void *allocate(size_t Size, size_t Alignment, bool &FromPool);
void *allocate(size_t Size, bool &FromPool);
void deallocate(void *Ptr, bool &ToPool, bool OwnZeMemHandle);
void deallocate(void *Ptr, bool &ToPool);

SystemMemory &getMemHandle() { return *MemHandle; }

Expand Down Expand Up @@ -332,7 +332,7 @@ Slab::Slab(Bucket &Bkt)

Slab::~Slab() {
unregSlab(*this);
bucket.getMemHandle().deallocate(MemPtr, true /* OwnZeMemHandle */);
bucket.getMemHandle().deallocate(MemPtr);
}

// Return the index of the first available chunk, -1 otherwize
Expand Down Expand Up @@ -737,8 +737,7 @@ Bucket &USMAllocContext::USMAllocImpl::findBucket(size_t Size) {
return *(*It);
}

void USMAllocContext::USMAllocImpl::deallocate(void *Ptr, bool &ToPool,
bool OwnZeMemHandle) {
void USMAllocContext::USMAllocImpl::deallocate(void *Ptr, bool &ToPool) {
auto *SlabPtr = AlignPtrDown(Ptr, SlabMinSize());

// Lock the map on read
Expand All @@ -748,7 +747,7 @@ void USMAllocContext::USMAllocImpl::deallocate(void *Ptr, bool &ToPool,
auto Slabs = getKnownSlabs().equal_range(SlabPtr);
if (Slabs.first == Slabs.second) {
Lk.unlock();
getMemHandle().deallocate(Ptr, OwnZeMemHandle);
getMemHandle().deallocate(Ptr);
return;
}

Expand Down Expand Up @@ -779,7 +778,7 @@ void USMAllocContext::USMAllocImpl::deallocate(void *Ptr, bool &ToPool,
// There is a rare case when we have a pointer from system allocation next
// to some slab with an entry in the map. So we find a slab
// but the range checks fail.
getMemHandle().deallocate(Ptr, OwnZeMemHandle);
getMemHandle().deallocate(Ptr);
}

USMAllocContext::USMAllocContext(std::unique_ptr<SystemMemory> MemHandle,
Expand Down Expand Up @@ -813,9 +812,9 @@ void *USMAllocContext::allocate(size_t size, size_t alignment) {
return Ptr;
}

void USMAllocContext::deallocate(void *ptr, bool OwnZeMemHandle) {
void USMAllocContext::deallocate(void *ptr) {
bool ToPool;
pImpl->deallocate(ptr, ToPool, OwnZeMemHandle);
pImpl->deallocate(ptr, ToPool);

if (pImpl->getParams().PoolTrace > 2) {
auto MT = pImpl->getParams().memoryTypeName;
Expand Down
4 changes: 2 additions & 2 deletions sycl/plugins/unified_runtime/ur/usm_allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SystemMemory {
public:
virtual void *allocate(size_t size) = 0;
virtual void *allocate(size_t size, size_t aligned) = 0;
virtual void deallocate(void *ptr, bool OwnZeMemHandle) = 0;
virtual void deallocate(void *ptr) = 0;
virtual ~SystemMemory() = default;
};

Expand Down Expand Up @@ -68,7 +68,7 @@ class USMAllocContext {

void *allocate(size_t size);
void *allocate(size_t size, size_t alignment);
void deallocate(void *ptr, bool OwnZeMemHandle);
void deallocate(void *ptr);

private:
std::unique_ptr<USMAllocImpl> pImpl;
Expand Down