Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions unified-runtime/source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,11 @@ ur_result_t urProgramLinkExp(
*phProgram = reinterpret_cast<ur_program_handle_t>(UrProgram);
return UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
}

ur_result_t UrResult = UR_RESULT_SUCCESS;
try {
// Acquire a "shared" lock on each of the input programs, and also
// validate that they are all in Object state for each device in the input
// list.
// validate that they are all in Object or Native state for each device in
// the input list.
//
// There is no danger of deadlock here even if two threads call
// urProgramLink simultaneously with the same input programs in a
Expand All @@ -407,13 +406,24 @@ ur_result_t urProgramLinkExp(
std::vector<std::shared_lock<ur_shared_mutex>> Guards(count);
const ur_program_handle_t_::CodeFormat CommonCodeFormat =
phPrograms[0]->getCodeFormat();
const ur_program_handle_t_::state CommonState =
phPrograms[0]->getState(phDevices[0]->ZeDevice);
if (CommonState != ur_program_handle_t_::Object &&
CommonState != ur_program_handle_t_::Native) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}
// Native programs are passed to this function to resolve external
// symbols, which requires multiple input programs.
if (CommonState == ur_program_handle_t_::Native && count == 1) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}
for (uint32_t I = 0; I < count; I++) {
std::shared_lock<ur_shared_mutex> Guard(phPrograms[I]->Mutex);
Guards[I].swap(Guard);

for (uint32_t DeviceIndex = 0; DeviceIndex < numDevices; DeviceIndex++) {
auto Device = phDevices[DeviceIndex];
if (phPrograms[I]->getState(Device->ZeDevice) !=
ur_program_handle_t_::Object) {
if (phPrograms[I]->getState(Device->ZeDevice) != CommonState) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}
}
Expand All @@ -426,11 +436,13 @@ ur_result_t urProgramLinkExp(
}
}

// Previous calls to urProgramCompile did not actually compile the SPIR-V.
// Instead, we postpone compilation until this point, when all the modules
// are linked together. By doing compilation and linking together, the
// JIT compiler is able see all modules and do cross-module optimizations.
//
// For SPIR-V input programs, previous calls to urProgramCompile did not
// actually compile the SPIR-V. Instead, we postpone compilation until this
// point, when all the modules are linked together. By doing compilation
// and linking together, the JIT compiler is able see all modules and do
// cross-module optimizations. This means that both SPIR-V and native
// programs can be passed to the module program extension directly.

// Construct a ze_module_program_exp_desc_t which contains information
// about all of the modules that will be linked together.
ZeStruct<ze_module_program_exp_desc_t> ZeExtModuleDesc;
Expand Down Expand Up @@ -460,6 +472,9 @@ ur_result_t urProgramLinkExp(
case ur_program_handle_t_::CodeFormat::SPIRV:
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
break;
case ur_program_handle_t_::CodeFormat::Native:
ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
break;
default:
ur::unreachable();
return UR_RESULT_ERROR_INVALID_PROGRAM;
Expand Down
Loading