From 58ba98b0d42b4189a6ce7def10fcefa6c2ed0e23 Mon Sep 17 00:00:00 2001 From: omarahmed1111 Date: Thu, 8 Feb 2024 14:38:21 +0000 Subject: [PATCH] Refactor ext-function caching --- source/adapters/opencl/adapter.cpp | 8 +- source/adapters/opencl/command_buffer.cpp | 138 +++++--------- source/adapters/opencl/command_buffer.hpp | 4 + source/adapters/opencl/common.hpp | 104 ---------- source/adapters/opencl/context.hpp | 2 + source/adapters/opencl/device.hpp | 1 - source/adapters/opencl/enqueue.cpp | 129 ++++++------- source/adapters/opencl/kernel.cpp | 55 +++--- source/adapters/opencl/kernel.hpp | 2 + source/adapters/opencl/memory.cpp | 20 +- source/adapters/opencl/platform.hpp | 73 +++++-- source/adapters/opencl/program.cpp | 37 ++-- source/adapters/opencl/queue.hpp | 2 + source/adapters/opencl/usm.cpp | 222 ++++++++++------------ 14 files changed, 335 insertions(+), 462 deletions(-) diff --git a/source/adapters/opencl/adapter.cpp b/source/adapters/opencl/adapter.cpp index 8ae1e77755..fbbdd84e59 100644 --- a/source/adapters/opencl/adapter.cpp +++ b/source/adapters/opencl/adapter.cpp @@ -22,9 +22,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (NumEntries > 0 && phAdapters) { std::lock_guard Lock{adapter.Mutex}; - if (adapter.RefCount++ == 0) { - cl_ext::ExtFuncPtrCache = std::make_unique(); - } + adapter.RefCount++; *phAdapters = &adapter; } @@ -43,9 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { std::lock_guard Lock{adapter.Mutex}; - if (--adapter.RefCount == 0) { - cl_ext::ExtFuncPtrCache.reset(); - } + --adapter.RefCount; return UR_RESULT_SUCCESS; } diff --git a/source/adapters/opencl/command_buffer.cpp b/source/adapters/opencl/command_buffer.cpp index 1c57246eca..cd7ff446c8 100644 --- a/source/adapters/opencl/command_buffer.cpp +++ b/source/adapters/opencl/command_buffer.cpp @@ -14,6 +14,7 @@ #include "event.hpp" #include "kernel.hpp" #include "memory.hpp" +#include "platform.hpp" #include "queue.hpp" UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( @@ -24,15 +25,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( ur_queue_handle_t Queue = nullptr; UR_RETURN_ON_FAILURE(urQueueCreate(hContext, hDevice, nullptr, &Queue)); - cl_context CLContext = hContext->get(); - cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clCreateCommandBufferKHRCache, - cl_ext::CreateCommandBufferName, &clCreateCommandBufferKHR); + ur_platform_handle_t Platform = hDevice->Platform; + cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = + Platform->ExtFuncPtr->clCreateCommandBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCreateCommandBufferKHR, + cl_ext::CreateCommandBufferName)); - if (!clCreateCommandBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + cl_int Res = 0; cl_command_queue CLQueue = Queue->get(); auto CLCommandBuffer = clCreateCommandBufferKHR(1, &CLQueue, nullptr, &Res); CL_RETURN_ON_FAILURE_AND_SET_NULL(Res, phCommandBuffer); @@ -55,14 +54,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { UR_RETURN_ON_FAILURE(urQueueRetain(hCommandBuffer->hInternalQueue)); - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clRetainCommandBufferKHR_fn clRetainCommandBuffer = nullptr; - cl_int Res = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clRetainCommandBufferKHRCache, - cl_ext::RetainCommandBufferName, &clRetainCommandBuffer); - - if (!clRetainCommandBuffer || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clRetainCommandBufferKHR_fn clRetainCommandBuffer = + Platform->ExtFuncPtr->clRetainCommandBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clRetainCommandBuffer, + cl_ext::RetainCommandBufferName)); CL_RETURN_ON_FAILURE(clRetainCommandBuffer(hCommandBuffer->CLCommandBuffer)); return UR_RESULT_SUCCESS; @@ -72,15 +68,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { UR_RETURN_ON_FAILURE(urQueueRelease(hCommandBuffer->hInternalQueue)); - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clReleaseCommandBufferKHRCache, - cl_ext::ReleaseCommandBufferName, &clReleaseCommandBufferKHR); - - if (!clReleaseCommandBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = + Platform->ExtFuncPtr->clReleaseCommandBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clReleaseCommandBufferKHR, + cl_ext::ReleaseCommandBufferName)); CL_RETURN_ON_FAILURE( clReleaseCommandBufferKHR(hCommandBuffer->CLCommandBuffer)); @@ -89,15 +81,11 @@ urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clFinalizeCommandBufferKHRCache, - cl_ext::FinalizeCommandBufferName, &clFinalizeCommandBufferKHR); - - if (!clFinalizeCommandBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = + Platform->ExtFuncPtr->clFinalizeCommandBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clFinalizeCommandBufferKHR, + cl_ext::FinalizeCommandBufferName)); CL_RETURN_ON_FAILURE( clFinalizeCommandBufferKHR(hCommandBuffer->CLCommandBuffer)); @@ -113,15 +101,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( ur_exp_command_buffer_sync_point_t *pSyncPoint, ur_exp_command_buffer_command_handle_t *) { - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clCommandNDRangeKernelKHRCache, - cl_ext::CommandNRRangeKernelName, &clCommandNDRangeKernelKHR); - - if (!clCommandNDRangeKernelKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = + Platform->ExtFuncPtr->clCommandNDRangeKernelKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandNDRangeKernelKHR, + cl_ext::CommandNRRangeKernelName)); CL_RETURN_ON_FAILURE(clCommandNDRangeKernelKHR( hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hKernel->get(), @@ -160,14 +144,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, ur_exp_command_buffer_sync_point_t *pSyncPoint) { - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr; - cl_int Res = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferKHRCache, - cl_ext::CommandCopyBufferName, &clCommandCopyBufferKHR); - - if (!clCommandCopyBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = + Platform->ExtFuncPtr->clCommandCopyBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandCopyBufferKHR, + cl_ext::CommandCopyBufferName)); CL_RETURN_ON_FAILURE(clCommandCopyBufferKHR( hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(), @@ -195,15 +176,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( size_t OpenCLDstRect[3]{dstOrigin.x, dstOrigin.y, dstOrigin.z}; size_t OpenCLRegion[3]{region.width, region.height, region.depth}; - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferRectKHRCache, - cl_ext::CommandCopyBufferRectName, &clCommandCopyBufferRectKHR); - - if (!clCommandCopyBufferRectKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = + Platform->ExtFuncPtr->clCommandCopyBufferRectKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandCopyBufferRectKHR, + cl_ext::CommandCopyBufferRectName)); CL_RETURN_ON_FAILURE(clCommandCopyBufferRectKHR( hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(), @@ -284,14 +261,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList, ur_exp_command_buffer_sync_point_t *pSyncPoint) { - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr; - cl_int Res = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clCommandFillBufferKHRCache, - cl_ext::CommandFillBufferName, &clCommandFillBufferKHR); - - if (!clCommandFillBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = + Platform->ExtFuncPtr->clCommandFillBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandFillBufferKHR, + cl_ext::CommandFillBufferName)); CL_RETURN_ON_FAILURE(clCommandFillBufferKHR( hCommandBuffer->CLCommandBuffer, nullptr, hBuffer->get(), pPattern, @@ -340,15 +314,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context CLContext = hCommandBuffer->hContext->get(); - cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueCommandBufferKHRCache, - cl_ext::EnqueueCommandBufferName, &clEnqueueCommandBufferKHR); - - if (!clEnqueueCommandBufferKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = + Platform->ExtFuncPtr->clEnqueueCommandBufferKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clEnqueueCommandBufferKHR, + cl_ext::EnqueueCommandBufferName)); const uint32_t NumberOfQueues = 1; cl_event Event; @@ -396,15 +366,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( ur_exp_command_buffer_info_t propName, size_t propSize, void *pPropValue, size_t *pPropSizeRet) { - cl_context CLContext = cl_adapter::cast(hCommandBuffer->hContext); - cl_ext::clGetCommandBufferInfoKHR_fn clGetCommandBufferInfoKHR = nullptr; - cl_int Res = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clGetCommandBufferInfoKHRCache, - cl_ext::GetCommandBufferInfoName, &clGetCommandBufferInfoKHR); - - if (!clGetCommandBufferInfoKHR || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; + ur_platform_handle_t Platform = hCommandBuffer->getPlatform(); + cl_ext::clGetCommandBufferInfoKHR_fn clGetCommandBufferInfoKHR = + Platform->ExtFuncPtr->clGetCommandBufferInfoKHRCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clGetCommandBufferInfoKHR, + cl_ext::GetCommandBufferInfoName)); if (propName != UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT) { return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/source/adapters/opencl/command_buffer.hpp b/source/adapters/opencl/command_buffer.hpp index d80f29594b..17e4e1f7b6 100644 --- a/source/adapters/opencl/command_buffer.hpp +++ b/source/adapters/opencl/command_buffer.hpp @@ -11,6 +11,8 @@ #include #include +#include "context.hpp" + struct ur_exp_command_buffer_handle_t_ { ur_queue_handle_t hInternalQueue; ur_context_handle_t hContext; @@ -21,4 +23,6 @@ struct ur_exp_command_buffer_handle_t_ { cl_command_buffer_khr CLCommandBuffer) : hInternalQueue(hQueue), hContext(hContext), CLCommandBuffer(CLCommandBuffer) {} + + ur_platform_handle_t getPlatform() { return hContext->Devices[0]->Platform; } }; diff --git a/source/adapters/opencl/common.hpp b/source/adapters/opencl/common.hpp index 0667cd3d17..2fd00afd36 100644 --- a/source/adapters/opencl/common.hpp +++ b/source/adapters/opencl/common.hpp @@ -305,110 +305,6 @@ using clGetCommandBufferInfoKHR_fn = CL_API_ENTRY cl_int(CL_API_CALL *)( cl_command_buffer_khr command_buffer, cl_command_buffer_info_khr param_name, size_t param_value_size, void *param_value, size_t *param_value_size_ret); -template struct FuncPtrCache { - std::map Map; - std::mutex Mutex; -}; - -// FIXME: There's currently no mechanism for cleaning up this cache, meaning -// that it is invalidated whenever a context is destroyed. This could lead to -// reusing an invalid function pointer if another context happens to have the -// same native handle. -struct ExtFuncPtrCacheT { - FuncPtrCache clHostMemAllocINTELCache; - FuncPtrCache clDeviceMemAllocINTELCache; - FuncPtrCache clSharedMemAllocINTELCache; - FuncPtrCache clGetDeviceFunctionPointerCache; - FuncPtrCache - clCreateBufferWithPropertiesINTELCache; - FuncPtrCache clMemBlockingFreeINTELCache; - FuncPtrCache - clSetKernelArgMemPointerINTELCache; - FuncPtrCache clEnqueueMemFillINTELCache; - FuncPtrCache clEnqueueMemcpyINTELCache; - FuncPtrCache clGetMemAllocInfoINTELCache; - FuncPtrCache - clEnqueueWriteGlobalVariableCache; - FuncPtrCache clEnqueueReadGlobalVariableCache; - FuncPtrCache clEnqueueReadHostPipeINTELCache; - FuncPtrCache clEnqueueWriteHostPipeINTELCache; - FuncPtrCache - clSetProgramSpecializationConstantCache; - FuncPtrCache clCreateCommandBufferKHRCache; - FuncPtrCache clRetainCommandBufferKHRCache; - FuncPtrCache clReleaseCommandBufferKHRCache; - FuncPtrCache clFinalizeCommandBufferKHRCache; - FuncPtrCache clCommandNDRangeKernelKHRCache; - FuncPtrCache clCommandCopyBufferKHRCache; - FuncPtrCache clCommandCopyBufferRectKHRCache; - FuncPtrCache clCommandFillBufferKHRCache; - FuncPtrCache clEnqueueCommandBufferKHRCache; - FuncPtrCache clGetCommandBufferInfoKHRCache; -}; -// A raw pointer is used here since the lifetime of this map has to be tied to -// piTeardown to avoid issues with static destruction order (a user application -// might have static objects that indirectly access this cache in their -// destructor). -inline std::unique_ptr ExtFuncPtrCache; - -// USM helper function to get an extension function pointer -template -static ur_result_t getExtFuncFromContext(cl_context Context, - FuncPtrCache &FPtrCache, - const char *FuncName, T *Fptr) { - // TODO - // Potentially redo caching as UR interface changes. - // if cached, return cached FuncPtr - std::lock_guard CacheLock{FPtrCache.Mutex}; - std::map &FPtrMap = FPtrCache.Map; - auto It = FPtrMap.find(Context); - if (It != FPtrMap.end()) { - auto F = It->second; - // if cached that extension is not available return nullptr and - // UR_RESULT_ERROR_INVALID_VALUE - *Fptr = F; - return F ? UR_RESULT_SUCCESS : UR_RESULT_ERROR_INVALID_VALUE; - } - - cl_uint DeviceCount; - cl_int RetErr = clGetContextInfo(Context, CL_CONTEXT_NUM_DEVICES, - sizeof(cl_uint), &DeviceCount, nullptr); - - if (RetErr != CL_SUCCESS || DeviceCount < 1) { - return UR_RESULT_ERROR_INVALID_CONTEXT; - } - - std::vector DevicesInCtx(DeviceCount); - RetErr = clGetContextInfo(Context, CL_CONTEXT_DEVICES, - DeviceCount * sizeof(cl_device_id), - DevicesInCtx.data(), nullptr); - - if (RetErr != CL_SUCCESS) { - return UR_RESULT_ERROR_INVALID_CONTEXT; - } - - cl_platform_id CurPlatform; - RetErr = clGetDeviceInfo(DevicesInCtx[0], CL_DEVICE_PLATFORM, - sizeof(cl_platform_id), &CurPlatform, nullptr); - - if (RetErr != CL_SUCCESS) { - return UR_RESULT_ERROR_INVALID_CONTEXT; - } - - T FuncPtr = reinterpret_cast( - clGetExtensionFunctionAddressForPlatform(CurPlatform, FuncName)); - - if (!FuncPtr) { - // Cache that the extension is not available - FPtrMap[Context] = nullptr; - return UR_RESULT_ERROR_INVALID_VALUE; - } - - *Fptr = FuncPtr; - FPtrMap[Context] = FuncPtr; - - return UR_RESULT_SUCCESS; -} } // namespace cl_ext ur_result_t mapCLErrorToUR(cl_int Result); diff --git a/source/adapters/opencl/context.hpp b/source/adapters/opencl/context.hpp index 555636b1b8..cc537b2c8f 100644 --- a/source/adapters/opencl/context.hpp +++ b/source/adapters/opencl/context.hpp @@ -81,5 +81,7 @@ struct ur_context_handle_t_ { native_type get() { return Context; } + ur_platform_handle_t getPlatform() { return Devices[0]->Platform; } + const std::vector &getDevices() { return Devices; } }; diff --git a/source/adapters/opencl/device.hpp b/source/adapters/opencl/device.hpp index 22f554ce73..b1cd437e8f 100644 --- a/source/adapters/opencl/device.hpp +++ b/source/adapters/opencl/device.hpp @@ -10,7 +10,6 @@ #pragma once #include "common.hpp" -#include "platform.hpp" struct ur_device_handle_t_ { using native_type = cl_device_id; diff --git a/source/adapters/opencl/enqueue.cpp b/source/adapters/opencl/enqueue.cpp index 7ffaefd733..11ea7c0fad 100644 --- a/source/adapters/opencl/enqueue.cpp +++ b/source/adapters/opencl/enqueue.cpp @@ -13,6 +13,7 @@ #include "event.hpp" #include "kernel.hpp" #include "memory.hpp" +#include "platform.hpp" #include "program.hpp" #include "queue.hpp" @@ -518,22 +519,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context Ctx = hQueue->Context->get(); + ur_platform_handle_t Platform = hQueue->getPlatform(); - cl_ext::clEnqueueWriteGlobalVariable_fn F = nullptr; - cl_int Res = cl_ext::getExtFuncFromContext( - Ctx, cl_ext::ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache, - cl_ext::EnqueueWriteGlobalVariableName, &F); + cl_ext::clEnqueueWriteGlobalVariable_fn clEnqueueWriteGlobalVariable = + Platform->ExtFuncPtr->clEnqueueWriteGlobalVariableCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc( + &clEnqueueWriteGlobalVariable, cl_ext::EnqueueWriteGlobalVariableName)); - if (!F || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; cl_event Event; std::vector CLWaitEvents(numEventsInWaitList); for (uint32_t i = 0; i < numEventsInWaitList; i++) { CLWaitEvents[i] = phEventWaitList[i]->get(); } - Res = F(hQueue->get(), hProgram->get(), name, blockingWrite, count, offset, - pSrc, numEventsInWaitList, CLWaitEvents.data(), &Event); + cl_int Res = clEnqueueWriteGlobalVariable( + hQueue->get(), hProgram->get(), name, blockingWrite, count, offset, pSrc, + numEventsInWaitList, CLWaitEvents.data(), &Event); if (phEvent) { try { auto UREvent = @@ -554,22 +554,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context Ctx = hQueue->Context->get(); + ur_platform_handle_t Platform = hQueue->getPlatform(); - cl_ext::clEnqueueReadGlobalVariable_fn F = nullptr; - cl_int Res = cl_ext::getExtFuncFromContext( - Ctx, cl_ext::ExtFuncPtrCache->clEnqueueReadGlobalVariableCache, - cl_ext::EnqueueReadGlobalVariableName, &F); + cl_ext::clEnqueueReadGlobalVariable_fn clEnqueueReadGlobalVariable = + Platform->ExtFuncPtr->clEnqueueReadGlobalVariableCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc( + &clEnqueueReadGlobalVariable, cl_ext::EnqueueReadGlobalVariableName)); - if (!F || Res != CL_SUCCESS) - return UR_RESULT_ERROR_INVALID_OPERATION; cl_event Event; std::vector CLWaitEvents(numEventsInWaitList); for (uint32_t i = 0; i < numEventsInWaitList; i++) { CLWaitEvents[i] = phEventWaitList[i]->get(); } - Res = F(hQueue->get(), hProgram->get(), name, blockingRead, count, offset, - pDst, numEventsInWaitList, CLWaitEvents.data(), &Event); + cl_int Res = clEnqueueReadGlobalVariable( + hQueue->get(), hProgram->get(), name, blockingRead, count, offset, pDst, + numEventsInWaitList, CLWaitEvents.data(), &Event); if (phEvent) { try { auto UREvent = @@ -590,33 +589,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context CLContext = hQueue->Context->get(); + ur_platform_handle_t Platform = hQueue->getPlatform(); - cl_ext::clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr; - ur_result_t RetVal = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache, - cl_ext::EnqueueReadHostPipeName, &FuncPtr); + cl_ext::clEnqueueReadHostPipeINTEL_fn clEnqueueReadHostPipe = + Platform->ExtFuncPtr->clEnqueueReadHostPipeINTELCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clEnqueueReadHostPipe, + cl_ext::EnqueueReadHostPipeName)); - if (FuncPtr) { - cl_event Event; - std::vector CLWaitEvents(numEventsInWaitList); - for (uint32_t i = 0; i < numEventsInWaitList; i++) { - CLWaitEvents[i] = phEventWaitList[i]->get(); - } - RetVal = mapCLErrorToUR(FuncPtr(hQueue->get(), hProgram->get(), pipe_symbol, - blocking, pDst, size, numEventsInWaitList, - CLWaitEvents.data(), &Event)); - if (phEvent) { - try { - auto UREvent = std::make_unique( - Event, hQueue->Context, hQueue); - *phEvent = UREvent.release(); - } catch (std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_RESOURCES; - } catch (...) { - return UR_RESULT_ERROR_UNKNOWN; - } + cl_event Event; + std::vector CLWaitEvents(numEventsInWaitList); + for (uint32_t i = 0; i < numEventsInWaitList; i++) { + CLWaitEvents[i] = phEventWaitList[i]->get(); + } + ur_result_t RetVal = mapCLErrorToUR(clEnqueueReadHostPipe( + hQueue->get(), hProgram->get(), pipe_symbol, blocking, pDst, size, + numEventsInWaitList, CLWaitEvents.data(), &Event)); + if (phEvent) { + try { + auto UREvent = + std::make_unique(Event, hQueue->Context, hQueue); + *phEvent = UREvent.release(); + } catch (std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_RESOURCES; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; } } @@ -629,33 +625,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context CLContext = hQueue->Context->get(); + ur_platform_handle_t Platform = hQueue->getPlatform(); - cl_ext::clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr; - ur_result_t RetVal = - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache, - cl_ext::EnqueueWriteHostPipeName, &FuncPtr); + cl_ext::clEnqueueWriteHostPipeINTEL_fn clEnqueueWriteHostPipe = + Platform->ExtFuncPtr->clEnqueueWriteHostPipeINTELCache; + UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clEnqueueWriteHostPipe, + cl_ext::EnqueueWriteHostPipeName)); - if (FuncPtr) { - cl_event Event; - std::vector CLWaitEvents(numEventsInWaitList); - for (uint32_t i = 0; i < numEventsInWaitList; i++) { - CLWaitEvents[i] = phEventWaitList[i]->get(); - } - RetVal = mapCLErrorToUR(FuncPtr(hQueue->get(), hProgram->get(), pipe_symbol, - blocking, pSrc, size, numEventsInWaitList, - CLWaitEvents.data(), &Event)); - if (phEvent) { - try { - auto UREvent = std::make_unique( - Event, hQueue->Context, hQueue); - *phEvent = UREvent.release(); - } catch (std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_RESOURCES; - } catch (...) { - return UR_RESULT_ERROR_UNKNOWN; - } + cl_event Event; + std::vector CLWaitEvents(numEventsInWaitList); + for (uint32_t i = 0; i < numEventsInWaitList; i++) { + CLWaitEvents[i] = phEventWaitList[i]->get(); + } + ur_result_t RetVal = mapCLErrorToUR(clEnqueueWriteHostPipe( + hQueue->get(), hProgram->get(), pipe_symbol, blocking, pSrc, size, + numEventsInWaitList, CLWaitEvents.data(), &Event)); + if (phEvent) { + try { + auto UREvent = + std::make_unique(Event, hQueue->Context, hQueue); + *phEvent = UREvent.release(); + } catch (std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_RESOURCES; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; } } diff --git a/source/adapters/opencl/kernel.cpp b/source/adapters/opencl/kernel.cpp index c1e283029c..9a485b0c1e 100644 --- a/source/adapters/opencl/kernel.cpp +++ b/source/adapters/opencl/kernel.cpp @@ -11,6 +11,7 @@ #include "common.hpp" #include "device.hpp" #include "memory.hpp" +#include "platform.hpp" #include "program.hpp" #include "sampler.hpp" @@ -294,42 +295,35 @@ urKernelRelease(ur_kernel_handle_t hKernel) { static ur_result_t usmSetIndirectAccess(ur_kernel_handle_t hKernel) { cl_bool TrueVal = CL_TRUE; - clHostMemAllocINTEL_fn HFunc = nullptr; - clSharedMemAllocINTEL_fn SFunc = nullptr; - clDeviceMemAllocINTEL_fn DFunc = nullptr; - cl_context CLContext; - + ur_platform_handle_t Platform = hKernel->getPlatform(); /* We test that each alloc type is supported before we actually try to set * KernelExecInfo. */ - CL_RETURN_ON_FAILURE(clGetKernelInfo(hKernel->get(), CL_KERNEL_CONTEXT, - sizeof(cl_context), &CLContext, - nullptr)); - - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache, - cl_ext::HostMemAllocName, &HFunc)); + clHostMemAllocINTEL_fn clHostMemAlloc = + Platform->ExtFuncPtr->clHostMemAllocINTELCache; + ur_result_t Res = + Platform->getExtFunc(&clHostMemAlloc, cl_ext::HostMemAllocName); - if (HFunc) { + if (Res == UR_RESULT_SUCCESS) { CL_RETURN_ON_FAILURE(clSetKernelExecInfo( hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL, sizeof(cl_bool), &TrueVal)); } - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clDeviceMemAllocINTELCache, - cl_ext::DeviceMemAllocName, &DFunc)); + clDeviceMemAllocINTEL_fn clDeviceMemAlloc = + Platform->ExtFuncPtr->clDeviceMemAllocINTELCache; + Res = Platform->getExtFunc(&clDeviceMemAlloc, cl_ext::DeviceMemAllocName); - if (DFunc) { + if (Res == UR_RESULT_SUCCESS) { CL_RETURN_ON_FAILURE(clSetKernelExecInfo( hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL, sizeof(cl_bool), &TrueVal)); } - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clSharedMemAllocINTELCache, - cl_ext::SharedMemAllocName, &SFunc)); + clSharedMemAllocINTEL_fn clSharedMemAlloc = + Platform->ExtFuncPtr->clSharedMemAllocINTELCache; + Res = Platform->getExtFunc(&clSharedMemAlloc, cl_ext::SharedMemAllocName); - if (SFunc) { + if (Res == UR_RESULT_SUCCESS) { CL_RETURN_ON_FAILURE(clSetKernelExecInfo( hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL, sizeof(cl_bool), &TrueVal)); @@ -374,24 +368,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer( sizeof(cl_context), &CLContext, nullptr)); - clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr; - UR_RETURN_ON_FAILURE( - cl_ext::getExtFuncFromContext( - CLContext, - cl_ext::ExtFuncPtrCache->clSetKernelArgMemPointerINTELCache, - cl_ext::SetKernelArgMemPointerName, &FuncPtr)); + ur_platform_handle_t Platform = hKernel->getPlatform(); - if (FuncPtr) { + clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointer = + Platform->ExtFuncPtr->clSetKernelArgMemPointerINTELCache; + ur_result_t Res = Platform->getExtFunc(&clSetKernelArgMemPointer, + cl_ext::SetKernelArgMemPointerName); + + if (Res == UR_RESULT_SUCCESS) { /* OpenCL passes pointers by value not by reference. This means we need to * deref the arg to get the pointer value */ auto PtrToPtr = reinterpret_cast(pArgValue); auto DerefPtr = reinterpret_cast(*PtrToPtr); - CL_RETURN_ON_FAILURE( - FuncPtr(hKernel->get(), cl_adapter::cast(argIndex), DerefPtr)); + CL_RETURN_ON_FAILURE(clSetKernelArgMemPointer( + hKernel->get(), cl_adapter::cast(argIndex), DerefPtr)); } return UR_RESULT_SUCCESS; } + UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle( ur_kernel_handle_t hKernel, ur_native_handle_t *phNativeKernel) { diff --git a/source/adapters/opencl/kernel.hpp b/source/adapters/opencl/kernel.hpp index 44651ebfc7..50f48b41c8 100644 --- a/source/adapters/opencl/kernel.hpp +++ b/source/adapters/opencl/kernel.hpp @@ -78,4 +78,6 @@ struct ur_kernel_handle_t_ { } native_type get() { return Kernel; } + + ur_platform_handle_t getPlatform() { return Context->Devices[0]->Platform; } }; diff --git a/source/adapters/opencl/memory.cpp b/source/adapters/opencl/memory.cpp index 9ebf347daf..5c43fc0184 100644 --- a/source/adapters/opencl/memory.cpp +++ b/source/adapters/opencl/memory.cpp @@ -11,6 +11,7 @@ #include "memory.hpp" #include "common.hpp" #include "context.hpp" +#include "platform.hpp" cl_image_format mapURImageFormatToCL(const ur_image_format_t *PImageFormat) { cl_image_format CLImageFormat; @@ -230,15 +231,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( if (pProperties) { // TODO: need to check if all properties are supported by OpenCL RT and // ignore unsupported - clCreateBufferWithPropertiesINTEL_fn FuncPtr = nullptr; - cl_context CLContext = hContext->get(); + // First we need to look up the function pointer - RetErr = - cl_ext::getExtFuncFromContext( - CLContext, - cl_ext::ExtFuncPtrCache->clCreateBufferWithPropertiesINTELCache, - cl_ext::CreateBufferWithPropertiesName, &FuncPtr); - if (FuncPtr) { + cl_context CLContext = hContext->get(); + ur_platform_handle_t Platform = hContext->getPlatform(); + clCreateBufferWithPropertiesINTEL_fn clCreateBufferWithProperties = + Platform->ExtFuncPtr->clCreateBufferWithPropertiesINTELCache; + ur_result_t Res = Platform->getExtFunc( + &clCreateBufferWithProperties, cl_ext::CreateBufferWithPropertiesName); + + if (Res == UR_RESULT_SUCCESS) { std::vector PropertiesIntel; auto Prop = static_cast(pProperties->pNext); while (Prop) { @@ -263,7 +265,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( PropertiesIntel.push_back(0); try { - cl_mem Buffer = FuncPtr( + cl_mem Buffer = clCreateBufferWithProperties( CLContext, PropertiesIntel.data(), static_cast(flags), size, pProperties->pHost, cl_adapter::cast(&RetErr)); CL_RETURN_ON_FAILURE(RetErr); diff --git a/source/adapters/opencl/platform.hpp b/source/adapters/opencl/platform.hpp index 6a9c49eb37..552705ed7c 100644 --- a/source/adapters/opencl/platform.hpp +++ b/source/adapters/opencl/platform.hpp @@ -14,32 +14,23 @@ #include +using namespace cl_ext; + struct ur_platform_handle_t_ { using native_type = cl_platform_id; native_type Platform = nullptr; std::vector> Devices; - ur_platform_handle_t_(native_type Plat) : Platform(Plat) {} + ur_platform_handle_t_(native_type Plat) : Platform(Plat) { + ExtFuncPtr = std::make_unique(); + } ~ur_platform_handle_t_() { for (auto &Dev : Devices) { Dev.reset(); } Devices.clear(); - } - - template - ur_result_t getExtFunc(T CachedExtFunc, const char *FuncName, T *Fptr) { - if (!CachedExtFunc) { - // TODO: check that the function is available - CachedExtFunc = reinterpret_cast( - clGetExtensionFunctionAddressForPlatform(Platform, FuncName)); - if (!CachedExtFunc) { - return UR_RESULT_ERROR_INVALID_VALUE; - } - } - *Fptr = CachedExtFunc; - return UR_RESULT_SUCCESS; + ExtFuncPtr.reset(); } native_type get() { return Platform; } @@ -86,4 +77,56 @@ struct ur_platform_handle_t_ { return UR_RESULT_SUCCESS; } + + struct ExtFuncPtrT { + clHostMemAllocINTEL_fn clHostMemAllocINTELCache = nullptr; + clDeviceMemAllocINTEL_fn clDeviceMemAllocINTELCache = nullptr; + clSharedMemAllocINTEL_fn clSharedMemAllocINTELCache = nullptr; + clGetDeviceFunctionPointer_fn clGetDeviceFunctionPointerCache = nullptr; + clCreateBufferWithPropertiesINTEL_fn + clCreateBufferWithPropertiesINTELCache = nullptr; + clMemBlockingFreeINTEL_fn clMemBlockingFreeINTELCache = nullptr; + clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointerINTELCache = + nullptr; + clEnqueueMemFillINTEL_fn clEnqueueMemFillINTELCache = nullptr; + clEnqueueMemcpyINTEL_fn clEnqueueMemcpyINTELCache = nullptr; + clGetMemAllocInfoINTEL_fn clGetMemAllocInfoINTELCache = nullptr; + clEnqueueWriteGlobalVariable_fn clEnqueueWriteGlobalVariableCache = nullptr; + clEnqueueReadGlobalVariable_fn clEnqueueReadGlobalVariableCache = nullptr; + clEnqueueReadHostPipeINTEL_fn clEnqueueReadHostPipeINTELCache = nullptr; + clEnqueueWriteHostPipeINTEL_fn clEnqueueWriteHostPipeINTELCache = nullptr; + clSetProgramSpecializationConstant_fn + clSetProgramSpecializationConstantCache = nullptr; + clCreateCommandBufferKHR_fn clCreateCommandBufferKHRCache = nullptr; + clRetainCommandBufferKHR_fn clRetainCommandBufferKHRCache = nullptr; + clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHRCache = nullptr; + clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHRCache = nullptr; + clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHRCache = nullptr; + clCommandCopyBufferKHR_fn clCommandCopyBufferKHRCache = nullptr; + clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHRCache = nullptr; + clCommandFillBufferKHR_fn clCommandFillBufferKHRCache = nullptr; + clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHRCache = nullptr; + clGetCommandBufferInfoKHR_fn clGetCommandBufferInfoKHRCache = nullptr; + }; + + std::unique_ptr ExtFuncPtr; + template + ur_result_t getExtFunc(T *CachedExtFunc, const char *FuncName) { + // Check that the function ext is supported by the device first. + bool Supported = false; + UR_RETURN_ON_FAILURE( + Devices[0]->checkDeviceExtensions({FuncName}, Supported)); + if (!Supported) { + return UR_RESULT_ERROR_INVALID_OPERATION; + } + + if (!(*CachedExtFunc)) { + *CachedExtFunc = reinterpret_cast( + clGetExtensionFunctionAddressForPlatform(Platform, FuncName)); + if (!(*CachedExtFunc)) { + return UR_RESULT_ERROR_INVALID_OPERATION; + } + } + return UR_RESULT_SUCCESS; + } }; diff --git a/source/adapters/opencl/program.cpp b/source/adapters/opencl/program.cpp index 799edf8d9c..5c69ce3af3 100644 --- a/source/adapters/opencl/program.cpp +++ b/source/adapters/opencl/program.cpp @@ -424,17 +424,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants( } } else { cl_ext::clSetProgramSpecializationConstant_fn - SetProgramSpecializationConstant = nullptr; - const ur_result_t URResult = cl_ext::getExtFuncFromContext< - decltype(SetProgramSpecializationConstant)>( - Ctx->get(), - cl_ext::ExtFuncPtrCache->clSetProgramSpecializationConstantCache, - cl_ext::SetProgramSpecializationConstantName, - &SetProgramSpecializationConstant); - - if (URResult != UR_RESULT_SUCCESS) { - return URResult; - } + SetProgramSpecializationConstant = + CurPlatform->ExtFuncPtr->clSetProgramSpecializationConstantCache; + UR_RETURN_ON_FAILURE( + CurPlatform->getExtFunc(&SetProgramSpecializationConstant, + cl_ext::SetProgramSpecializationConstantName)); for (uint32_t i = 0; i < count; ++i) { CL_RETURN_ON_FAILURE(SetProgramSpecializationConstant( @@ -475,16 +469,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer( ur_device_handle_t hDevice, ur_program_handle_t hProgram, const char *pFunctionName, void **ppFunctionPointer) { - cl_context CLContext = hProgram->Context->get(); - - cl_ext::clGetDeviceFunctionPointer_fn FuncT = nullptr; - - UR_RETURN_ON_FAILURE( - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clGetDeviceFunctionPointerCache, - cl_ext::GetDeviceFunctionPointerName, &FuncT)); + ur_platform_handle_t Platform = hDevice->Platform; + cl_ext::clGetDeviceFunctionPointer_fn clGetDeviceFunctionPointer = + Platform->ExtFuncPtr->clGetDeviceFunctionPointerCache; + ur_result_t Res = Platform->getExtFunc(&clGetDeviceFunctionPointer, + cl_ext::GetDeviceFunctionPointerName); - if (!FuncT) { + if (Res != UR_RESULT_SUCCESS) { return UR_RESULT_ERROR_INVALID_FUNCTION_NAME; } @@ -511,9 +502,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer( return UR_RESULT_ERROR_INVALID_KERNEL_NAME; } - const cl_int CLResult = - FuncT(hDevice->get(), hProgram->get(), pFunctionName, - reinterpret_cast(ppFunctionPointer)); + const cl_int CLResult = clGetDeviceFunctionPointer( + hDevice->get(), hProgram->get(), pFunctionName, + reinterpret_cast(ppFunctionPointer)); // GPU runtime sometimes returns CL_INVALID_ARG_VALUE if the function address // cannot be found but the kernel exists. As the kernel does exist, return // that the function name is invalid. diff --git a/source/adapters/opencl/queue.hpp b/source/adapters/opencl/queue.hpp index e44af5f4d9..e5723f3204 100644 --- a/source/adapters/opencl/queue.hpp +++ b/source/adapters/opencl/queue.hpp @@ -75,4 +75,6 @@ struct ur_queue_handle_t_ { uint32_t getReferenceCount() const noexcept { return RefCount; } native_type get() { return Queue; } + + ur_platform_handle_t getPlatform() { return Device->Platform; } }; diff --git a/source/adapters/opencl/usm.cpp b/source/adapters/opencl/usm.cpp index 6e1917a034..afadb70ba6 100644 --- a/source/adapters/opencl/usm.cpp +++ b/source/adapters/opencl/usm.cpp @@ -12,6 +12,7 @@ #include "context.hpp" #include "device.hpp" #include "event.hpp" +#include "platform.hpp" #include "queue.hpp" inline cl_mem_alloc_flags_intel @@ -97,24 +98,21 @@ urUSMHostAlloc(ur_context_handle_t Context, const ur_usm_desc_t *pUSMDesc, } // First we need to look up the function pointer - clHostMemAllocINTEL_fn FuncPtr = nullptr; cl_context CLContext = Context->get(); - if (auto UrResult = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache, - cl_ext::HostMemAllocName, &FuncPtr)) { - return UrResult; - } - - if (FuncPtr) { - cl_int ClResult = CL_SUCCESS; - Ptr = FuncPtr(CLContext, - AllocProperties.empty() ? nullptr : AllocProperties.data(), - size, Alignment, &ClResult); - if (ClResult == CL_INVALID_BUFFER_SIZE) { - return UR_RESULT_ERROR_INVALID_USM_SIZE; - } - CL_RETURN_ON_FAILURE(ClResult); + ur_platform_handle_t Platform = Context->getPlatform(); + clHostMemAllocINTEL_fn clHostMemAlloc = + Platform->ExtFuncPtr->clHostMemAllocINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clHostMemAlloc, cl_ext::HostMemAllocName)); + + cl_int ClResult = CL_SUCCESS; + Ptr = clHostMemAlloc( + CLContext, AllocProperties.empty() ? nullptr : AllocProperties.data(), + size, Alignment, &ClResult); + if (ClResult == CL_INVALID_BUFFER_SIZE) { + return UR_RESULT_ERROR_INVALID_USM_SIZE; } + CL_RETURN_ON_FAILURE(ClResult); *ppMem = Ptr; @@ -140,24 +138,22 @@ urUSMDeviceAlloc(ur_context_handle_t Context, ur_device_handle_t hDevice, } // First we need to look up the function pointer - clDeviceMemAllocINTEL_fn FuncPtr = nullptr; cl_context CLContext = Context->get(); - if (auto UrResult = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clDeviceMemAllocINTELCache, - cl_ext::DeviceMemAllocName, &FuncPtr)) { - return UrResult; - } - - if (FuncPtr) { - cl_int ClResult = CL_SUCCESS; - Ptr = FuncPtr(CLContext, hDevice->get(), - AllocProperties.empty() ? nullptr : AllocProperties.data(), - size, Alignment, &ClResult); - if (ClResult == CL_INVALID_BUFFER_SIZE) { - return UR_RESULT_ERROR_INVALID_USM_SIZE; - } - CL_RETURN_ON_FAILURE(ClResult); + ur_platform_handle_t Platform = hDevice->Platform; + clDeviceMemAllocINTEL_fn clDeviceMemAlloc = + Platform->ExtFuncPtr->clDeviceMemAllocINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clDeviceMemAlloc, cl_ext::DeviceMemAllocName)); + + cl_int ClResult = CL_SUCCESS; + Ptr = clDeviceMemAlloc(CLContext, hDevice->get(), + AllocProperties.empty() ? nullptr + : AllocProperties.data(), + size, Alignment, &ClResult); + if (ClResult == CL_INVALID_BUFFER_SIZE) { + return UR_RESULT_ERROR_INVALID_USM_SIZE; } + CL_RETURN_ON_FAILURE(ClResult); *ppMem = Ptr; @@ -183,24 +179,22 @@ urUSMSharedAlloc(ur_context_handle_t Context, ur_device_handle_t hDevice, } // First we need to look up the function pointer - clSharedMemAllocINTEL_fn FuncPtr = nullptr; cl_context CLContext = Context->get(); - if (auto UrResult = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clSharedMemAllocINTELCache, - cl_ext::SharedMemAllocName, &FuncPtr)) { - return UrResult; - } - - if (FuncPtr) { - cl_int ClResult = CL_SUCCESS; - Ptr = FuncPtr(CLContext, hDevice->get(), - AllocProperties.empty() ? nullptr : AllocProperties.data(), - size, Alignment, cl_adapter::cast(&ClResult)); - if (ClResult == CL_INVALID_BUFFER_SIZE) { - return UR_RESULT_ERROR_INVALID_USM_SIZE; - } - CL_RETURN_ON_FAILURE(ClResult); + ur_platform_handle_t Platform = hDevice->Platform; + clSharedMemAllocINTEL_fn clSharedMemAlloc = + Platform->ExtFuncPtr->clSharedMemAllocINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clSharedMemAlloc, cl_ext::SharedMemAllocName)); + + cl_int ClResult = CL_SUCCESS; + Ptr = clSharedMemAlloc( + CLContext, hDevice->get(), + AllocProperties.empty() ? nullptr : AllocProperties.data(), size, + Alignment, cl_adapter::cast(&ClResult)); + if (ClResult == CL_INVALID_BUFFER_SIZE) { + return UR_RESULT_ERROR_INVALID_USM_SIZE; } + CL_RETURN_ON_FAILURE(ClResult); *ppMem = Ptr; @@ -215,19 +209,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t Context, // Use a blocking free to avoid issues with indirect access from kernels that // might be still running. - clMemBlockingFreeINTEL_fn FuncPtr = nullptr; - cl_context CLContext = Context->get(); - ur_result_t RetVal = UR_RESULT_ERROR_INVALID_OPERATION; - RetVal = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache, - cl_ext::MemBlockingFreeName, &FuncPtr); + ur_platform_handle_t Platform = Context->getPlatform(); + clMemBlockingFreeINTEL_fn clMemBlockingFree = + Platform->ExtFuncPtr->clMemBlockingFreeINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clMemBlockingFree, cl_ext::MemBlockingFreeName)); - if (FuncPtr) { - RetVal = mapCLErrorToUR(FuncPtr(CLContext, pMem)); - } - - return RetVal; + return mapCLErrorToUR(clMemBlockingFree(CLContext, pMem)); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( @@ -236,13 +225,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { // Have to look up the context from the kernel cl_context CLContext = hQueue->Context->get(); - + ur_platform_handle_t Platform = hQueue->Context->getPlatform(); if (patternSize <= 128) { - clEnqueueMemFillINTEL_fn EnqueueMemFill = nullptr; + clEnqueueMemFillINTEL_fn EnqueueMemFill = + Platform->ExtFuncPtr->clEnqueueMemFillINTELCache; UR_RETURN_ON_FAILURE( - cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemFillINTELCache, - cl_ext::EnqueueMemFillName, &EnqueueMemFill)); + Platform->getExtFunc(&EnqueueMemFill, cl_ext::EnqueueMemFillName)); + cl_event Event; std::vector CLWaitEvents(numEventsInWaitList); for (uint32_t i = 0; i < numEventsInWaitList; i++) { @@ -268,20 +257,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( // OpenCL only supports pattern sizes as large as the largest CL type // (double16/long16 - 128 bytes), anything larger we need to do on the host // side and copy it into the target allocation. - clHostMemAllocINTEL_fn HostMemAlloc = nullptr; - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache, - cl_ext::HostMemAllocName, &HostMemAlloc)); + clHostMemAllocINTEL_fn HostMemAlloc = + Platform->ExtFuncPtr->clHostMemAllocINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&HostMemAlloc, cl_ext::HostMemAllocName)); - clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr; - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache, - cl_ext::EnqueueMemcpyName, &USMMemcpy)); + clEnqueueMemcpyINTEL_fn USMMemcpy = + Platform->ExtFuncPtr->clEnqueueMemcpyINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&USMMemcpy, cl_ext::EnqueueMemcpyName)); - clMemBlockingFreeINTEL_fn USMFree = nullptr; - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache, - cl_ext::MemBlockingFreeName, &USMFree)); + clMemBlockingFreeINTEL_fn USMFree = + Platform->ExtFuncPtr->clMemBlockingFreeINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&USMFree, cl_ext::MemBlockingFreeName)); cl_int ClErr = CL_SUCCESS; auto HostBuffer = static_cast( @@ -360,32 +349,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { // Have to look up the context from the kernel - cl_context CLContext = hQueue->Context->get(); + ur_platform_handle_t Platform = hQueue->Context->getPlatform(); + clEnqueueMemcpyINTEL_fn clEnqueueMemcpy = + Platform->ExtFuncPtr->clEnqueueMemcpyINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clEnqueueMemcpy, cl_ext::EnqueueMemcpyName)); - clEnqueueMemcpyINTEL_fn FuncPtr = nullptr; - ur_result_t RetVal = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache, - cl_ext::EnqueueMemcpyName, &FuncPtr); - - if (FuncPtr) { - cl_event Event; - std::vector CLWaitEvents(numEventsInWaitList); - for (uint32_t i = 0; i < numEventsInWaitList; i++) { - CLWaitEvents[i] = phEventWaitList[i]->get(); - } - RetVal = mapCLErrorToUR(FuncPtr(hQueue->get(), blocking, pDst, pSrc, size, - numEventsInWaitList, CLWaitEvents.data(), - &Event)); - if (phEvent) { - try { - auto UREvent = std::make_unique( - Event, hQueue->Context, hQueue); - *phEvent = UREvent.release(); - } catch (std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_RESOURCES; - } catch (...) { - return UR_RESULT_ERROR_UNKNOWN; - } + cl_event Event; + std::vector CLWaitEvents(numEventsInWaitList); + for (uint32_t i = 0; i < numEventsInWaitList; i++) { + CLWaitEvents[i] = phEventWaitList[i]->get(); + } + ur_result_t RetVal = mapCLErrorToUR( + clEnqueueMemcpy(hQueue->get(), blocking, pDst, pSrc, size, + numEventsInWaitList, CLWaitEvents.data(), &Event)); + if (phEvent) { + try { + auto UREvent = + std::make_unique(Event, hQueue->Context, hQueue); + *phEvent = UREvent.release(); + } catch (std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_RESOURCES; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; } } @@ -495,16 +481,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( const void *pSrc, size_t srcPitch, size_t width, size_t height, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_context CLContext = hQueue->Context->get(); - clEnqueueMemcpyINTEL_fn FuncPtr = nullptr; - ur_result_t RetVal = cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache, - cl_ext::EnqueueMemcpyName, &FuncPtr); - - if (!FuncPtr) { - return RetVal; - } + ur_platform_handle_t Platform = hQueue->Context->getPlatform(); + clEnqueueMemcpyINTEL_fn clEnqueueMemcpy = + Platform->ExtFuncPtr->clEnqueueMemcpyINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&clEnqueueMemcpy, cl_ext::EnqueueMemcpyName)); std::vector Events(height); for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) { @@ -513,11 +495,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( for (uint32_t i = 0; i < numEventsInWaitList; i++) { CLWaitEvents[i] = phEventWaitList[i]->get(); } - auto ClResult = - FuncPtr(hQueue->get(), false, - static_cast(pDst) + dstPitch * HeightIndex, - static_cast(pSrc) + srcPitch * HeightIndex, - width, numEventsInWaitList, CLWaitEvents.data(), &Event); + auto ClResult = clEnqueueMemcpy( + hQueue->get(), false, + static_cast(pDst) + dstPitch * HeightIndex, + static_cast(pSrc) + srcPitch * HeightIndex, width, + numEventsInWaitList, CLWaitEvents.data(), &Event); Events[HeightIndex] = Event; if (ClResult != CL_SUCCESS) { for (const auto &E : Events) { @@ -572,11 +554,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo( ur_context_handle_t Context, const void *pMem, ur_usm_alloc_info_t propName, size_t propSize, void *pPropValue, size_t *pPropSizeRet) { - clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr; - cl_context CLContext = Context->get(); - UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext( - CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache, - cl_ext::GetMemAllocInfoName, &GetMemAllocInfo)); + ur_platform_handle_t Platform = Context->getPlatform(); + clGetMemAllocInfoINTEL_fn GetMemAllocInfo = + Platform->ExtFuncPtr->clGetMemAllocInfoINTELCache; + UR_RETURN_ON_FAILURE( + Platform->getExtFunc(&GetMemAllocInfo, cl_ext::GetMemAllocInfoName)); cl_mem_info_intel PropNameCL; switch (propName) {