Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support UR program creation from multiple device binaries #2147

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4201,17 +4201,19 @@ urProgramCreateWithIL(
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Create a program object from device native binary.
/// @brief Create a program object from native binaries for the specified
/// devices.
///
/// @details
/// - The application may call this function from simultaneous threads.
/// - Following a successful call to this entry point, `phProgram` will
/// contain a binary of type ::UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or
/// ::UR_PROGRAM_BINARY_TYPE_LIBRARY for `hDevice`.
/// - The device specified by `hDevice` must be device associated with
/// contain binaries of type ::UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or
/// ::UR_PROGRAM_BINARY_TYPE_LIBRARY for the specified devices in
/// `phDevices`.
/// - The devices specified by `phDevices` must be associated with the
/// context.
/// - The adapter may (but is not required to) perform validation of the
/// provided module during this call.
/// provided modules during this call.
///
/// @remarks
/// _Analogues_
Expand All @@ -4224,21 +4226,27 @@ urProgramCreateWithIL(
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pBinary`
/// + `NULL == phDevices`
/// + `NULL == pLengths`
/// + `NULL == ppBinaries`
/// + `NULL == phProgram`
/// + `NULL != pProperties && pProperties->count > 0 && NULL == pProperties->pMetadatas`
/// - ::UR_RESULT_ERROR_INVALID_SIZE
/// + `NULL != pProperties && NULL != pProperties->pMetadatas && pProperties->count == 0`
/// + `numDevices == 0`
/// - ::UR_RESULT_ERROR_INVALID_NATIVE_BINARY
/// + If `pBinary` isn't a valid binary for `hDevice.`
/// + If any binary in `ppBinaries` isn't a valid binary for the corresponding device in `phDevices.`
UR_APIEXPORT ur_result_t UR_APICALL
urProgramCreateWithBinary(
ur_context_handle_t hContext, ///< [in] handle of the context instance
ur_device_handle_t hDevice, ///< [in] handle to device associated with binary.
size_t size, ///< [in] size in bytes.
const uint8_t *pBinary, ///< [in] pointer to binary.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] a pointer to a list of device handles. The
///< binaries are loaded for devices specified in this list.
size_t *pLengths, ///< [in][range(0, numDevices)] array of sizes of program binaries
///< specified by `pBinaries` (in bytes).
const uint8_t **ppBinaries, ///< [in][range(0, numDevices)] pointer to program binaries to be loaded
///< for devices specified by `phDevices`.
const ur_program_properties_t *pProperties, ///< [in][optional] pointer to program creation properties.
ur_program_handle_t *phProgram ///< [out] pointer to handle of Program object created.
);
Expand Down Expand Up @@ -10035,9 +10043,10 @@ typedef struct ur_program_create_with_il_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_create_with_binary_params_t {
ur_context_handle_t *phContext;
ur_device_handle_t *phDevice;
size_t *psize;
const uint8_t **ppBinary;
uint32_t *pnumDevices;
ur_device_handle_t **pphDevices;
size_t **ppLengths;
const uint8_t ***pppBinaries;
const ur_program_properties_t **ppProperties;
ur_program_handle_t **pphProgram;
} ur_program_create_with_binary_params_t;
Expand Down
7 changes: 4 additions & 3 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,10 @@ typedef ur_result_t(UR_APICALL *ur_pfnProgramCreateWithIL_t)(
/// @brief Function-pointer for urProgramCreateWithBinary
typedef ur_result_t(UR_APICALL *ur_pfnProgramCreateWithBinary_t)(
ur_context_handle_t,
ur_device_handle_t,
size_t,
const uint8_t *,
uint32_t,
ur_device_handle_t *,
size_t *,
const uint8_t **,
const ur_program_properties_t *,
ur_program_handle_t *);

Expand Down
39 changes: 31 additions & 8 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11127,21 +11127,44 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
*(params->phContext));

os << ", ";
os << ".hDevice = ";
os << ".numDevices = ";

ur::details::printPtr(os,
*(params->phDevice));
os << *(params->pnumDevices);

os << ", ";
os << ".size = ";
os << ".phDevices = {";
for (size_t i = 0; *(params->pphDevices) != NULL && i < *params->pnumDevices; ++i) {
if (i != 0) {
os << ", ";
}

os << *(params->psize);
ur::details::printPtr(os,
(*(params->pphDevices))[i]);
}
os << "}";

os << ", ";
os << ".pBinary = ";
os << ".pLengths = {";
for (size_t i = 0; *(params->ppLengths) != NULL && i < *params->pnumDevices; ++i) {
if (i != 0) {
os << ", ";
}

ur::details::printPtr(os,
*(params->ppBinary));
os << (*(params->ppLengths))[i];
}
os << "}";

os << ", ";
os << ".ppBinaries = {";
for (size_t i = 0; *(params->pppBinaries) != NULL && i < *params->pnumDevices; ++i) {
if (i != 0) {
os << ", ";
}

ur::details::printPtr(os,
(*(params->pppBinaries))[i]);
}
os << "}";

os << ", ";
os << ".pProperties = ";
Expand Down
32 changes: 18 additions & 14 deletions scripts/core/program.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ returns:
- "`length == 0`"
--- #--------------------------------------------------------------------------
type: function
desc: "Create a program object from device native binary."
desc: "Create a program object from native binaries for the specified devices."
class: $xProgram
name: CreateWithBinary
decl: static
Expand All @@ -128,22 +128,25 @@ analogue:
- "**clCreateProgramWithBinary**"
details:
- "The application may call this function from simultaneous threads."
- "Following a successful call to this entry point, `phProgram` will contain a binary of type $X_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or $X_PROGRAM_BINARY_TYPE_LIBRARY for `hDevice`."
- "The device specified by `hDevice` must be device associated with context."
- "The adapter may (but is not required to) perform validation of the provided module during this call."
- "Following a successful call to this entry point, `phProgram` will contain binaries of type $X_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or $X_PROGRAM_BINARY_TYPE_LIBRARY for the specified devices in `phDevices`."
- "The devices specified by `phDevices` must be associated with the context."
- "The adapter may (but is not required to) perform validation of the provided modules during this call."
params:
- type: $x_context_handle_t
name: hContext
desc: "[in] handle of the context instance"
- type: $x_device_handle_t
name: hDevice
desc: "[in] handle to device associated with binary."
- type: size_t
name: size
desc: "[in] size in bytes."
- type: const uint8_t*
name: pBinary
desc: "[in] pointer to binary."
- type: uint32_t
name: numDevices
desc: "[in] number of devices"
- type: $x_device_handle_t*
name: phDevices
desc: "[in][range(0, numDevices)] a pointer to a list of device handles. The binaries are loaded for devices specified in this list."
- type: size_t*
name: pLengths
desc: "[in][range(0, numDevices)] array of sizes of program binaries specified by `pBinaries` (in bytes)."
- type: const uint8_t**
name: ppBinaries
desc: "[in][range(0, numDevices)] pointer to program binaries to be loaded for devices specified by `phDevices`."
- type: const $x_program_properties_t*
name: pProperties
desc: "[in][optional] pointer to program creation properties."
Expand All @@ -155,8 +158,9 @@ returns:
- "`NULL != pProperties && pProperties->count > 0 && NULL == pProperties->pMetadatas`"
- $X_RESULT_ERROR_INVALID_SIZE:
- "`NULL != pProperties && NULL != pProperties->pMetadatas && pProperties->count == 0`"
- "`numDevices == 0`"
- $X_RESULT_ERROR_INVALID_NATIVE_BINARY:
- "If `pBinary` isn't a valid binary for `hDevice.`"
- "If any binary in `ppBinaries` isn't a valid binary for the corresponding device in `phDevices.`"
--- #--------------------------------------------------------------------------
type: function
desc: "Produces an executable program from one program, negates need for the linking step."
Expand Down
11 changes: 7 additions & 4 deletions source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,12 +482,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
ur_context_handle_t hContext, uint32_t numDevices,
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
const ur_program_properties_t *pProperties,
ur_program_handle_t *phProgram) {
if (numDevices > 1)
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

UR_CHECK_ERROR(
createProgram(hContext, hDevice, size, pBinary, pProperties, phProgram));
UR_CHECK_ERROR(createProgram(hContext, phDevices[0], pLengths[0],
ppBinaries[0], pProperties, phProgram));
(*phProgram)->BinaryType = UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT;

return UR_RESULT_SUCCESS;
Expand Down
11 changes: 9 additions & 2 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
///
/// Note: Only supports one device
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
ur_context_handle_t hContext, uint32_t numDevices,
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
const ur_program_properties_t *pProperties,
ur_program_handle_t *phProgram) {
if (numDevices > 1)
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

auto hDevice = phDevices[0];
auto pBinary = ppBinaries[0];
auto size = pLengths[0];
UR_ASSERT(std::find(hContext->getDevices().begin(),
hContext->getDevices().end(),
hDevice) != hContext->getDevices().end(),
Expand Down
41 changes: 15 additions & 26 deletions source/adapters/level_zero/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,18 +495,11 @@ ur_result_t urEnqueueDeviceGlobalVariableWrite(
///< this particular kernel execution instance.
) {
std::scoped_lock<ur_shared_mutex> lock(Queue->Mutex);

ze_module_handle_t ZeModule{};
auto It = Program->ZeModuleMap.find(Queue->Device->ZeDevice);
if (It != Program->ZeModuleMap.end()) {
ZeModule = It->second;
} else {
ZeModule = Program->ZeModule;
}

// Find global variable pointer
size_t GlobalVarSize = 0;
void *GlobalVarPtr = nullptr;
ze_module_handle_t ZeModule =
Program->getZeModuleHandle(Queue->Device->ZeDevice);
ZE2UR_CALL(zeModuleGetGlobalPointer,
(ZeModule, Name, &GlobalVarSize, &GlobalVarPtr));
if (GlobalVarSize < Offset + Count) {
Expand Down Expand Up @@ -557,15 +550,8 @@ ur_result_t urEnqueueDeviceGlobalVariableRead(
///< this particular kernel execution instance.
) {
std::scoped_lock<ur_shared_mutex> lock(Queue->Mutex);

ze_module_handle_t ZeModule{};
auto It = Program->ZeModuleMap.find(Queue->Device->ZeDevice);
if (It != Program->ZeModuleMap.end()) {
ZeModule = It->second;
} else {
ZeModule = Program->ZeModule;
}

ze_module_handle_t ZeModule =
Program->getZeModuleHandle(Queue->Device->ZeDevice);
// Find global variable pointer
size_t GlobalVarSize = 0;
void *GlobalVarPtr = nullptr;
Expand Down Expand Up @@ -603,10 +589,6 @@ ur_result_t urKernelCreate(
*RetKernel ///< [out] pointer to handle of kernel object created.
) {
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
if (Program->State != ur_program_handle_t_::state::Exe) {
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
}

try {
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);
Expand All @@ -616,8 +598,14 @@ ur_result_t urKernelCreate(
return UR_RESULT_ERROR_UNKNOWN;
}

for (auto It : Program->ZeModuleMap) {
auto ZeModule = It.second;
for (auto &Dev : Program->AssociatedDevices) {
auto ZeDevice = Dev->ZeDevice;
// Program may be associated with all devices from the context but built
// only for subset of devices.
if (Program->getState(ZeDevice) != ur_program_handle_t_::state::Exe)
continue;

auto ZeModule = Program->getZeModuleHandle(ZeDevice);
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
ZeKernelDesc.flags = 0;
ZeKernelDesc.pKernelName = KernelName;
Expand All @@ -632,8 +620,6 @@ ur_result_t urKernelCreate(
return ze2urResult(ZeResult);
}

auto ZeDevice = It.first;

// Store the kernel in the ZeKernelMap so the correct
// kernel can be retrieved later for a specific device
// where a queue is being submitted.
Expand All @@ -651,6 +637,9 @@ ur_result_t urKernelCreate(
(*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel;
}
}
// There is no any successfully built executable for program.
if ((*RetKernel)->ZeKernelMap.empty())
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;

(*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second;

Expand Down
Loading
Loading