Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1123,19 +1123,20 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld(
wait_list_view &waitListView, ur_event_handle_t phEvent) {
{
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
ur_device_handle_t hDevice = this->hDevice.get();
for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) {
switch (pArgs[argIndex].type) {
case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
UR_CALL(hKernel->setArgValue(pArgs[argIndex].index,
UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index,
pArgs[argIndex].size, nullptr, nullptr));
break;
case UR_EXP_KERNEL_ARG_TYPE_VALUE:
UR_CALL(hKernel->setArgValue(pArgs[argIndex].index,
UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index,
pArgs[argIndex].size, nullptr,
pArgs[argIndex].value.value));
break;
case UR_EXP_KERNEL_ARG_TYPE_POINTER:
UR_CALL(hKernel->setArgPointer(pArgs[argIndex].index, nullptr,
UR_CALL(hKernel->setArgPointer(hDevice, pArgs[argIndex].index, nullptr,
pArgs[argIndex].value.pointer));
break;
case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
Expand All @@ -1147,7 +1148,7 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld(
break;
case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: {
UR_CALL(
hKernel->setArgValue(argIndex, sizeof(void *), nullptr,
hKernel->setArgValue(hDevice, argIndex, sizeof(void *), nullptr,
&pArgs[argIndex].value.sampler->ZeSampler));
break;
}
Expand Down
21 changes: 15 additions & 6 deletions unified-runtime/source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,20 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const {
}

ur_result_t ur_kernel_handle_t_::setArgValue(
uint32_t argIndex, size_t argSize,
ur_device_handle_t hDevice, uint32_t argIndex, size_t argSize,
const ur_kernel_arg_value_properties_t * /*pProperties*/,
const void *pArgValue) {
if (argIndex > zeCommonProperties.numKernelArgs - 1) {
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
}

if (hDevice) { // Set argument only on the specified device
auto &deviceKernel = deviceKernels[deviceIndex(hDevice)].value();
UR_CALL(setArgValueOnZeKernel(deviceKernel.hKernel.get(), argIndex, argSize,
pArgValue));
return UR_RESULT_SUCCESS;
}

for (auto &singleDeviceKernel : deviceKernels) {
if (!singleDeviceKernel.has_value()) {
continue;
Expand All @@ -213,12 +220,13 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
}

ur_result_t ur_kernel_handle_t_::setArgPointer(
uint32_t argIndex,
ur_device_handle_t hDevice, uint32_t argIndex,
const ur_kernel_arg_pointer_properties_t * /*pProperties*/,
const void *pArgValue) {

// KernelSetArgValue is expecting a pointer to the argument
return setArgValue(argIndex, sizeof(const void *), nullptr, &pArgValue);
return setArgValue(hDevice, argIndex, sizeof(const void *), nullptr,
&pArgValue);
}

ur_program_handle_t ur_kernel_handle_t_::getProgramHandle() const {
Expand Down Expand Up @@ -429,7 +437,8 @@ ur_result_t urKernelSetArgValue(
TRACK_SCOPE_LATENCY("urKernelSetArgValue");

std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
return hKernel->setArgValue(argIndex, argSize, pProperties, pArgValue);
return hKernel->setArgValue(nullptr, argIndex, argSize, pProperties,
pArgValue);
} catch (...) {
return exceptionToResult(std::current_exception());
}
Expand Down Expand Up @@ -492,7 +501,7 @@ ur_result_t urKernelSetArgLocal(

std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);

return hKernel->setArgValue(argIndex, argSize, nullptr, nullptr);
return hKernel->setArgValue(nullptr, argIndex, argSize, nullptr, nullptr);
} catch (...) {
return exceptionToResult(std::current_exception());
}
Expand Down Expand Up @@ -736,7 +745,7 @@ ur_result_t urKernelSetArgSampler(
ur_sampler_handle_t hArgValue) try {
TRACK_SCOPE_LATENCY("urKernelSetArgSampler");
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
return hKernel->setArgValue(argIndex, sizeof(void *), nullptr,
return hKernel->setArgValue(nullptr, argIndex, sizeof(void *), nullptr,
&hArgValue->ZeSampler);
} catch (...) {
return exceptionToResult(std::current_exception());
Expand Down
5 changes: 3 additions & 2 deletions unified-runtime/source/adapters/level_zero/v2/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ struct ur_kernel_handle_t_ : ur_object {
const ze_kernel_properties_t &getProperties(ur_device_handle_t hDevice) const;

// Implementation of urKernelSetArgValue.
ur_result_t setArgValue(uint32_t argIndex, size_t argSize,
ur_result_t setArgValue(ur_device_handle_t hDevice, uint32_t argIndex,
size_t argSize,
const ur_kernel_arg_value_properties_t *pProperties,
const void *pArgValue);

// Implementation of urKernelSetArgPointer.
ur_result_t
setArgPointer(uint32_t argIndex,
setArgPointer(ur_device_handle_t hDevice, uint32_t argIndex,
const ur_kernel_arg_pointer_properties_t *pProperties,
const void *pArgValue);

Expand Down
Loading