@@ -30,12 +30,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
30
30
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
31
31
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
32
32
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
33
+ std::vector<size_t > compiledLocalWorksize;
34
+ if (!pLocalWorkSize) {
35
+ cl_device_id device = nullptr ;
36
+ CL_RETURN_ON_FAILURE (clGetCommandQueueInfo (
37
+ cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_DEVICE,
38
+ sizeof (device), &device, nullptr ));
39
+ // This query always returns size_t[3], if nothing was specified it returns
40
+ // all zeroes.
41
+ size_t queriedLocalWorkSize[3 ] = {0 , 0 , 0 };
42
+ CL_RETURN_ON_FAILURE (clGetKernelWorkGroupInfo (
43
+ cl_adapter::cast<cl_kernel>(hKernel), device,
44
+ CL_KERNEL_COMPILE_WORK_GROUP_SIZE, sizeof (size_t [3 ]),
45
+ queriedLocalWorkSize, nullptr ));
46
+ if (queriedLocalWorkSize[0 ] != 0 ) {
47
+ for (uint32_t i = 0 ; i < workDim; i++) {
48
+ compiledLocalWorksize.push_back (queriedLocalWorkSize[i]);
49
+ }
50
+ }
51
+ }
33
52
34
53
CL_RETURN_ON_FAILURE (clEnqueueNDRangeKernel (
35
54
cl_adapter::cast<cl_command_queue>(hQueue),
36
55
cl_adapter::cast<cl_kernel>(hKernel), workDim, pGlobalWorkOffset,
37
- pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList,
38
- cl_adapter::cast<const cl_event *>(phEventWaitList),
56
+ pGlobalWorkSize,
57
+ compiledLocalWorksize.empty () ? pLocalWorkSize
58
+ : compiledLocalWorksize.data (),
59
+ numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
39
60
cl_adapter::cast<cl_event *>(phEvent)));
40
61
41
62
return UR_RESULT_SUCCESS;
0 commit comments