@@ -33,6 +33,17 @@ ur_single_device_kernel_t::ur_single_device_kernel_t(ur_device_handle_t hDevice,
3333 };
3434}
3535
36+ ur_result_t ur_single_device_kernel_t::setArgValue (uint32_t argIndex,
37+ size_t argSize,
38+ const void *pArgValue) {
39+ return setArgValueOnZeKernel (hKernel.get (), argIndex, argSize, pArgValue);
40+ }
41+
42+ ur_result_t ur_single_device_kernel_t::setArgPointer (uint32_t argIndex,
43+ const void *pArgValue) {
44+ return setArgValue (argIndex, sizeof (void *), &pArgValue);
45+ }
46+
3647ur_result_t ur_single_device_kernel_t::release () {
3748 hKernel.reset ();
3849 return UR_RESULT_SUCCESS;
@@ -187,19 +198,6 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
187198 uint32_t argIndex, size_t argSize,
188199 const ur_kernel_arg_value_properties_t * /* pProperties*/ ,
189200 const void *pArgValue) {
190-
191- // OpenCL: "the arg_value pointer can be NULL or point to a NULL value
192- // in which case a NULL value will be used as the value for the argument
193- // declared as a pointer to global or constant memory in the kernel"
194- //
195- // We don't know the type of the argument but it seems that the only time
196- // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
197- // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
198- if (argSize == sizeof (void *) && pArgValue &&
199- *(void **)(const_cast <void *>(pArgValue)) == nullptr ) {
200- pArgValue = nullptr ;
201- }
202-
203201 if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
204202 return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
205203 }
@@ -209,15 +207,8 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
209207 continue ;
210208 }
211209
212- auto zeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
213- (singleDeviceKernel.value ().hKernel .get (),
214- argIndex, argSize, pArgValue));
215-
216- if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
217- return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
218- } else if (zeResult != ZE_RESULT_SUCCESS) {
219- return ze2urResult (zeResult);
220- }
210+ UR_CALL (setArgValueOnZeKernel (singleDeviceKernel.value ().hKernel .get (),
211+ argIndex, argSize, pArgValue));
221212 }
222213 return UR_RESULT_SUCCESS;
223214}
@@ -281,7 +272,11 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
281272 const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
282273 uint32_t groupSizeY, uint32_t groupSizeZ,
283274 ze_command_list_handle_t commandList, wait_list_view &waitListView) {
284- auto hZeKernel = getZeHandle (hDevice);
275+ auto &deviceKernelOpt = deviceKernels[deviceIndex (hDevice)];
276+ if (!deviceKernelOpt.has_value ())
277+ return UR_RESULT_ERROR_INVALID_KERNEL;
278+ auto &deviceKernel = deviceKernelOpt.value ();
279+ auto hZeKernel = deviceKernel.hKernel .get ();
285280
286281 if (pGlobalWorkOffset != NULL ) {
287282 UR_CALL (
@@ -304,10 +299,17 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
304299 zePtr = reinterpret_cast <void *>(hImage->getZeImage ());
305300 }
306301 }
307- UR_CALL (setArgPointer (pending.argIndex , nullptr , zePtr));
302+ // Set the argument only on this device's kernel.
303+ UR_CALL (deviceKernel.setArgPointer (pending.argIndex , zePtr));
308304 }
309305 pending_allocations.clear ();
310306
307+ // Apply any pending raw pointer arguments (USM pointers) for this device.
308+ for (auto &pending : pending_pointer_args) {
309+ UR_CALL (deviceKernel.setArgPointer (pending.argIndex , pending.ptrArgValue ));
310+ }
311+ pending_pointer_args.clear ();
312+
311313 return UR_RESULT_SUCCESS;
312314}
313315
@@ -322,6 +324,18 @@ ur_result_t ur_kernel_handle_t_::addPendingMemoryAllocation(
322324 return UR_RESULT_SUCCESS;
323325}
324326
327+ ur_result_t
328+ ur_kernel_handle_t_::addPendingPointerArgument (uint32_t argIndex,
329+ const void *pArgValue) {
330+ if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
331+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
332+ }
333+
334+ pending_pointer_args.push_back ({argIndex, pArgValue});
335+
336+ return UR_RESULT_SUCCESS;
337+ }
338+
325339std::vector<char > ur_kernel_handle_t_::getSourceAttributes () const {
326340 uint32_t size;
327341 ZE2UR_CALL_THROWS (zeKernelGetSourceAttributes,
@@ -408,14 +422,16 @@ ur_result_t urKernelSetArgPointer(
408422 ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
409423 uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
410424 const ur_kernel_arg_pointer_properties_t
411- *pProperties, // /< [in][optional] argument properties
425+ * /* pProperties*/ , // /< [in][optional] argument properties
412426 const void
413427 *pArgValue // /< [in] argument value represented as matching arg type.
414428 ) try {
415429 TRACK_SCOPE_LATENCY (" urKernelSetArgPointer" );
416430
417431 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
418- return hKernel->setArgPointer (argIndex, pProperties, pArgValue);
432+ // Store the raw pointer value and defer setting the
433+ // argument until we know the device where kernel is being submitted.
434+ return hKernel->addPendingPointerArgument (argIndex, pArgValue);
419435} catch (...) {
420436 return exceptionToResult (std::current_exception ());
421437}
0 commit comments