Skip to content

Commit

Permalink
[L0 v2] implement urEnqueueMemBuffer[Map/Unmap]
Browse files Browse the repository at this point in the history
and extend ur_mem_handle_t implementations to support
async memory migration (right now, this is only used
for keeping data in sync between device and host allocations).

Also, implement generic memcpy/fill functions in queue which
can be used by both USM and Buffer operations.
  • Loading branch information
igchor committed Oct 11, 2024
1 parent 9553d93 commit c08e73b
Show file tree
Hide file tree
Showing 6 changed files with 525 additions and 338 deletions.
8 changes: 6 additions & 2 deletions source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,

auto kernelDevices = hKernel->getDevices();
if (kernelDevices.size() == 1) {
auto zePtr = hArgValue->getPtr(kernelDevices.front());
auto zePtr = hArgValue->getDevicePtr(
kernelDevices.front(), ur_mem_handle_t_::access_mode_t::read_write, 0,
hArgValue->getSize(), nullptr);
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
} else {
// TODO: if devices do not have p2p capabilities, we need to have allocation
Expand All @@ -324,7 +326,9 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
// Get memory that is accessible by the first device.
// If kernel is submitted to a different device the memory
// will be accessed through the link or migrated in enqueueKernelLaunch.
auto zePtr = hArgValue->getPtr(kernelDevices.front());
auto zePtr = hArgValue->getDevicePtr(
kernelDevices.front(), ur_mem_handle_t_::access_mode_t::read_write, 0,
hArgValue->getSize(), nullptr);
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
}
}
Expand Down
190 changes: 162 additions & 28 deletions source/adapters/level_zero/v2/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,41 @@
ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size)
: hContext(hContext), size(size) {}

ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
void *hostPtr, size_t size,
host_ptr_action_t hostPtrAction)
ur_usm_handle_t_::ur_usm_handle_t_(ur_context_handle_t hContext, size_t size,
const void *ptr)
: ur_mem_handle_t_(hContext, size), ptr(const_cast<void *>(ptr)) {}

// The destructor body is empty: the wrapped pointer is not released here.
ur_usm_handle_t_::~ur_usm_handle_t_() = default;

// Returns the address of the byte at `offset` inside the wrapped pointer.
//
// The same stored pointer is used for every device, so `hDevice`, `access`,
// `size` and `migrate` are unused and no migration callback is ever invoked.
//
// The offset is applied here so that this implementation agrees with
// ur_discrete_mem_handle_t, whose getDevicePtr returns base + offset; callers
// passing a non-zero offset get the same result for every handle kind.
void *ur_usm_handle_t_::getDevicePtr(
    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  std::ignore = hDevice;
  std::ignore = access;
  std::ignore = size;
  std::ignore = migrate;
  return static_cast<char *>(ptr) + offset;
}

// Returns a host-usable address for the region starting at `offset`.
//
// No staging allocation or copy is made: the result points directly into the
// wrapped pointer, so `access`, `size` and the migration callback are unused.
//
// The offset is applied so the returned pointer refers to the first byte of
// the requested region, as ur_discrete_mem_handle_t::mapHostPtr does.
void *ur_usm_handle_t_::mapHostPtr(
    access_mode_t access, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)>) {
  std::ignore = access;
  std::ignore = size;
  return static_cast<char *>(ptr) + offset;
}

// Counterpart of ur_usm_handle_t_::mapHostPtr. Mapping creates no state for
// this handle kind, so there is nothing to undo; both arguments are unused.
void ur_usm_handle_t_::unmapHostPtr(
    [[maybe_unused]] void *pMappedPtr,
    std::function<void(void *src, void *dst, size_t)>) {
  // Intentionally a no-op.
}

ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
ur_context_handle_t hContext, void *hostPtr, size_t size,
host_ptr_action_t hostPtrAction)
: ur_mem_handle_t_(hContext, size) {
bool hostPtrImported = false;
if (hostPtrAction == host_ptr_action_t::import) {
Expand All @@ -37,7 +69,7 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
}
}

ur_host_mem_handle_t::~ur_host_mem_handle_t() {
ur_integrated_mem_handle_t::~ur_integrated_mem_handle_t() {
if (ptr) {
auto ret = hContext->getDefaultUSMPool()->free(ptr);
if (ret != UR_RESULT_SUCCESS) {
Expand All @@ -46,21 +78,36 @@ ur_host_mem_handle_t::~ur_host_mem_handle_t() {
}
}

void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
// Returns the address of the byte at `offset` inside this buffer's single
// allocation.
//
// The same allocation is handed out for every device, so `hDevice`, `access`,
// `size` and `migrate` are unused and nothing is migrated.
//
// The offset is applied here so that this implementation agrees with
// ur_discrete_mem_handle_t, whose getDevicePtr returns base + offset.
void *ur_integrated_mem_handle_t::getDevicePtr(
    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  std::ignore = hDevice;
  std::ignore = access;
  std::ignore = size;
  std::ignore = migrate;
  return static_cast<char *>(ptr) + offset;
}

ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
void *src, size_t size) {
auto Id = hDevice->Id.value();
// Returns a host-usable address for the region starting at `offset`.
//
// No staging copy is made: the result points directly into this buffer's
// allocation, so `access`, `size` and `migrate` are unused.
//
// The offset is applied so the returned pointer refers to the first byte of
// the requested region, as ur_discrete_mem_handle_t::mapHostPtr does.
void *ur_integrated_mem_handle_t::mapHostPtr(
    access_mode_t access, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)> migrate) {
  std::ignore = access;
  std::ignore = size;
  std::ignore = migrate;
  return static_cast<char *>(ptr) + offset;
}

if (!deviceAllocations[Id]) {
UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
UR_USM_TYPE_DEVICE, size,
&deviceAllocations[Id]));
}
// Counterpart of ur_integrated_mem_handle_t::mapHostPtr. Mapping creates no
// state for this handle kind, so there is nothing to undo; both arguments are
// unused.
void ur_integrated_mem_handle_t::unmapHostPtr(
    [[maybe_unused]] void *pMappedPtr,
    std::function<void(void *src, void *dst, size_t)>) {
  // Intentionally a no-op.
}

static ur_result_t synchronousZeCopy(ur_context_handle_t hContext,
ur_device_handle_t hDevice, void *dst,
const void *src, size_t size) {
auto commandList = hContext->commandListCache.getImmediateCommandList(
hDevice->ZeDevice, true,
hDevice
Expand All @@ -70,26 +117,42 @@ ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
std::nullopt);

ZE2UR_CALL(zeCommandListAppendMemoryCopy,
(commandList.get(), deviceAllocations[Id], src, size, nullptr, 0,
nullptr));
(commandList.get(), dst, src, size, nullptr, 0, nullptr));

return UR_RESULT_SUCCESS;
}

// Copies `size` bytes from `src` into this buffer's allocation on `hDevice`
// and records `hDevice` as the device holding the active allocation.
//
// The per-device allocation slot is filled on first use with a
// UR_USM_TYPE_DEVICE allocation of `size` bytes from the context's default
// USM pool. The copy itself is delegated to synchronousZeCopy.
//
// Returns UR_RESULT_SUCCESS on success; allocation or copy failures are
// returned through UR_CALL.
ur_result_t
ur_discrete_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice, void *src,
                                          size_t size) {
  auto &deviceAllocation = deviceAllocations[hDevice->Id.value()];

  // Allocate lazily: only the first migration to a given device pays for it.
  if (!deviceAllocation) {
    UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
                                                    UR_USM_TYPE_DEVICE, size,
                                                    &deviceAllocation));
  }

  UR_CALL(synchronousZeCopy(hContext, hDevice, deviceAllocation, src, size));

  activeAllocationDevice = hDevice;
  return UR_RESULT_SUCCESS;
}

ur_device_mem_handle_t::ur_device_mem_handle_t(ur_context_handle_t hContext,
void *hostPtr, size_t size)
ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
void *hostPtr, size_t size)
: ur_mem_handle_t_(hContext, size),
deviceAllocations(hContext->getPlatform()->getNumDevices()),
activeAllocationDevice(nullptr) {
activeAllocationDevice(nullptr), hostAllocations() {
if (hostPtr) {
auto initialDevice = hContext->getDevices()[0];
UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
}
}

ur_device_mem_handle_t::~ur_device_mem_handle_t() {
ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
for (auto &ptr : deviceAllocations) {
if (ptr) {
auto ret = hContext->getDefaultUSMPool()->free(ptr);
Expand All @@ -100,8 +163,12 @@ ur_device_mem_handle_t::~ur_device_mem_handle_t() {
}
}

void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
std::lock_guard lock(this->Mutex);
void *ur_discrete_mem_handle_t::getDevicePtrUnlocked(
ur_device_handle_t hDevice, access_mode_t access, size_t offset,
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
std::ignore = access;
std::ignore = size;
std::ignore = migrate;

if (!activeAllocationDevice) {
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
Expand All @@ -110,8 +177,10 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
activeAllocationDevice = hDevice;
}

char *ptr;
if (activeAllocationDevice == hDevice) {
return deviceAllocations[hDevice->Id.value()];
ptr = ur_cast<char *>(deviceAllocations[hDevice->Id.value()]);
return ptr + offset;
}

auto &p2pDevices = hContext->getP2PDevices(hDevice);
Expand All @@ -124,7 +193,71 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
}

// TODO: see if it's better to migrate the memory to the specified device
return deviceAllocations[activeAllocationDevice->Id.value()];
return ur_cast<char *>(
deviceAllocations[activeAllocationDevice->Id.value()]) +
offset;
}

// Locking entry point for device-pointer lookup: acquires this->Mutex and
// forwards every argument unchanged to getDevicePtrUnlocked.
void *ur_discrete_mem_handle_t::getDevicePtr(
    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
  std::lock_guard guard(this->Mutex);
  void *devicePtr =
      getDevicePtrUnlocked(hDevice, access, offset, size, migrate);
  return devicePtr;
}

// Creates a host-side staging copy of `size` bytes of the buffer starting at
// `offset` and returns a pointer to it.
//
// A UR_USM_TYPE_HOST allocation is taken from the context's default USM pool
// and recorded in hostAllocations together with its size, offset and access
// mode so that unmapHostPtr can find it later. When a device allocation is
// active and the mapping is not write-only, `migrate(src, dst, size)` is
// invoked to copy the device data into the new host allocation.
//
// Allocation failures are thrown via UR_CALL_THROWS.
// NOTE(review): if `migrate` throws, the host allocation and its record are
// left in place — confirm whether callers rely on that.
void *ur_discrete_mem_handle_t::mapHostPtr(
    access_mode_t access, size_t offset, size_t size,
    std::function<void(void *src, void *dst, size_t)> migrate) {
  std::lock_guard guard(this->Mutex);

  // TODO: use async alloc?

  void *hostPtr = nullptr;
  UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
      hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &hostPtr));

  hostAllocations.emplace_back(hostPtr, size, offset, access);
  void *mappedPtr = hostAllocations.back().ptr;

  // Nothing to read back if no device copy exists yet or the caller will
  // overwrite the whole region anyway.
  if (!activeAllocationDevice || access == access_mode_t::write_only) {
    return mappedPtr;
  }

  char *deviceBase =
      ur_cast<char *>(deviceAllocations[activeAllocationDevice->Id.value()]);
  migrate(deviceBase + offset, mappedPtr, size);

  return mappedPtr;
}

void ur_discrete_mem_handle_t::unmapHostPtr(
void *pMappedPtr,
std::function<void(void *src, void *dst, size_t)> migrate) {
std::lock_guard lock(this->Mutex);

for (auto &hostAllocation : hostAllocations) {
if (hostAllocation.ptr == pMappedPtr) {
void *devicePtr = nullptr;
if (activeAllocationDevice) {
devicePtr = ur_cast<char *>(
deviceAllocations[activeAllocationDevice->Id.value()]) +
hostAllocation.offset;
} else if (hostAllocation.access != access_mode_t::write_invalidate) {
devicePtr = ur_cast<char *>(getDevicePtrUnlocked(
hContext->getDevices()[0], access_mode_t::read_only,
hostAllocation.offset, hostAllocation.size, migrate));
}

if (devicePtr) {
migrate(hostAllocation.ptr, devicePtr, hostAllocation.size);
}

// TODO: use async free here?
UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(hostAllocation.ptr));
return;
}
}

// No mapping found
throw UR_RESULT_ERROR_INVALID_ARGUMENT;
}

namespace ur::level_zero {
Expand Down Expand Up @@ -155,13 +288,14 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
if (useHostBuffer) {
// TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
// or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
auto hostPtrAction = flags & UR_MEM_FLAG_USE_HOST_POINTER
? ur_host_mem_handle_t::host_ptr_action_t::import
: ur_host_mem_handle_t::host_ptr_action_t::copy;
auto hostPtrAction =
flags & UR_MEM_FLAG_USE_HOST_POINTER
? ur_integrated_mem_handle_t::host_ptr_action_t::import
: ur_integrated_mem_handle_t::host_ptr_action_t::copy;
*phBuffer =
new ur_host_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
new ur_integrated_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
} else {
*phBuffer = new ur_device_mem_handle_t(hContext, hostPtr, size);
*phBuffer = new ur_discrete_mem_handle_t(hContext, hostPtr, size);
}

return UR_RESULT_SUCCESS;
Expand Down
Loading

0 comments on commit c08e73b

Please sign in to comment.