@@ -2577,62 +2577,62 @@ pi_result cuda_piEnqueueKernelLaunch(
2577
2577
size_t maxThreadsPerBlock[3 ] = {};
2578
2578
bool providedLocalWorkGroupSize = (local_work_size != nullptr );
2579
2579
pi_uint32 local_size = kernel->get_local_size ();
2580
+ pi_result retError = PI_SUCCESS;
2580
2581
2581
- // Set the active context here as guessLocalWorkSize needs an active context
2582
- ScopedContext active (command_queue->get_context ());
2583
- {
2584
- size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_ ;
2585
- maxWorkGroupSize = command_queue->device_ ->get_max_work_group_size ();
2586
- command_queue->device_ ->get_max_work_item_sizes (sizeof (maxThreadsPerBlock),
2587
- maxThreadsPerBlock);
2588
-
2589
- if (providedLocalWorkGroupSize) {
2590
- auto isValid = [&](int dim) {
2591
- if (reqdThreadsPerBlock[dim] != 0 &&
2592
- local_work_size[dim] != reqdThreadsPerBlock[dim])
2593
- return PI_INVALID_WORK_GROUP_SIZE;
2594
-
2595
- if (local_work_size[dim] > maxThreadsPerBlock[dim])
2596
- return PI_INVALID_WORK_ITEM_SIZE;
2597
- // Checks that local work sizes are a divisor of the global work sizes
2598
- // which includes that the local work sizes are neither larger than the
2599
- // global work sizes and not 0.
2600
- if (0u == local_work_size[dim])
2601
- return PI_INVALID_WORK_GROUP_SIZE;
2602
- if (0u != (global_work_size[dim] % local_work_size[dim]))
2603
- return PI_INVALID_WORK_GROUP_SIZE;
2604
- threadsPerBlock[dim] = static_cast <int >(local_work_size[dim]);
2605
- return PI_SUCCESS;
2606
- };
2607
-
2608
- for (size_t dim = 0 ; dim < work_dim; dim++) {
2609
- auto err = isValid (dim);
2610
- if (err != PI_SUCCESS)
2611
- return err;
2582
+ try {
2583
+ // Set the active context here as guessLocalWorkSize needs an active context
2584
+ ScopedContext active (command_queue->get_context ());
2585
+ {
2586
+ size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_ ;
2587
+ maxWorkGroupSize = command_queue->device_ ->get_max_work_group_size ();
2588
+ command_queue->device_ ->get_max_work_item_sizes (
2589
+ sizeof (maxThreadsPerBlock), maxThreadsPerBlock);
2590
+
2591
+ if (providedLocalWorkGroupSize) {
2592
+ auto isValid = [&](int dim) {
2593
+ if (reqdThreadsPerBlock[dim] != 0 &&
2594
+ local_work_size[dim] != reqdThreadsPerBlock[dim])
2595
+ return PI_INVALID_WORK_GROUP_SIZE;
2596
+
2597
+ if (local_work_size[dim] > maxThreadsPerBlock[dim])
2598
+ return PI_INVALID_WORK_ITEM_SIZE;
2599
+ // Checks that local work sizes are a divisor of the global work sizes
2600
+ // which includes that the local work sizes are neither larger than
2601
+ // the global work sizes and not 0.
2602
+ if (0u == local_work_size[dim])
2603
+ return PI_INVALID_WORK_GROUP_SIZE;
2604
+ if (0u != (global_work_size[dim] % local_work_size[dim]))
2605
+ return PI_INVALID_WORK_GROUP_SIZE;
2606
+ threadsPerBlock[dim] = static_cast <int >(local_work_size[dim]);
2607
+ return PI_SUCCESS;
2608
+ };
2609
+
2610
+ for (size_t dim = 0 ; dim < work_dim; dim++) {
2611
+ auto err = isValid (dim);
2612
+ if (err != PI_SUCCESS)
2613
+ return err;
2614
+ }
2615
+ } else {
2616
+ guessLocalWorkSize (threadsPerBlock, global_work_size,
2617
+ maxThreadsPerBlock, kernel, local_size);
2612
2618
}
2613
- } else {
2614
- guessLocalWorkSize (threadsPerBlock, global_work_size, maxThreadsPerBlock,
2615
- kernel, local_size);
2616
2619
}
2617
- }
2618
2620
2619
- if (maxWorkGroupSize <
2620
- size_t (threadsPerBlock[0 ] * threadsPerBlock[1 ] * threadsPerBlock[2 ])) {
2621
- return PI_INVALID_WORK_GROUP_SIZE;
2622
- }
2621
+ if (maxWorkGroupSize <
2622
+ size_t (threadsPerBlock[0 ] * threadsPerBlock[1 ] * threadsPerBlock[2 ])) {
2623
+ return PI_INVALID_WORK_GROUP_SIZE;
2624
+ }
2623
2625
2624
- int blocksPerGrid[3 ] = {1 , 1 , 1 };
2626
+ int blocksPerGrid[3 ] = {1 , 1 , 1 };
2625
2627
2626
- for (size_t i = 0 ; i < work_dim; i++) {
2627
- blocksPerGrid[i] =
2628
- static_cast <int >(global_work_size[i] + threadsPerBlock[i] - 1 ) /
2629
- threadsPerBlock[i];
2630
- }
2628
+ for (size_t i = 0 ; i < work_dim; i++) {
2629
+ blocksPerGrid[i] =
2630
+ static_cast <int >(global_work_size[i] + threadsPerBlock[i] - 1 ) /
2631
+ threadsPerBlock[i];
2632
+ }
2631
2633
2632
- pi_result retError = PI_SUCCESS;
2633
- std::unique_ptr<_pi_event> retImplEv{nullptr };
2634
+ std::unique_ptr<_pi_event> retImplEv{nullptr };
2634
2635
2635
- try {
2636
2636
CUstream cuStream = command_queue->get ();
2637
2637
CUfunction cuFunc = kernel->get ();
2638
2638
0 commit comments