@@ -542,9 +542,28 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
542542
543543 static std::once_flag initFlag;
544544 static _pi_platform platformId;
545- std::call_once (initFlag,
546- [](pi_result &err) { err = PI_CHECK_ERROR (cuInit (0 )); },
547- err);
545+ std::call_once (
546+ initFlag,
547+ [](pi_result &err) {
548+ err = PI_CHECK_ERROR (cuInit (0 ));
549+
550+ int numDevices = 0 ;
551+ err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
552+ platformId.devices_ .reserve (numDevices);
553+ try {
554+ for (int i = 0 ; i < numDevices; ++i) {
555+ CUdevice device;
556+ err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
557+ platformId.devices_ .emplace_back (
558+ new _pi_device{device, &platformId});
559+ }
560+ } catch (...) {
561+ // Clear and rethrow to allow retry
562+ platformId.devices_ .clear ();
563+ throw ;
564+ }
565+ },
566+ err);
548567
549568 *platforms = &platformId;
550569 }
@@ -594,22 +613,16 @@ pi_result cuda_piDevicesGet(pi_platform platform, pi_device_type device_type,
594613
595614 pi_result err = PI_SUCCESS;
596615 const bool askingForGPU = (device_type & PI_DEVICE_TYPE_GPU);
597- size_t numDevices = askingForGPU ? 1 : 0 ;
616+ size_t numDevices = askingForGPU ? platform-> devices_ . size () : 0 ;
598617
599618 try {
600619 if (num_devices) {
601620 *num_devices = numDevices;
602621 }
603622
604- if (askingForGPU) {
605- if (devices) {
606- CUdevice device;
607- err = PI_CHECK_ERROR (cuDeviceGet (&device, 0 ));
608- *devices = new _pi_device{device, platform};
609- }
610- } else {
611- if (devices) {
612- *devices = nullptr ;
623+ if (askingForGPU && devices) {
624+ for (size_t i = 0 ; i < std::min (size_t (num_entries), numDevices); ++i) {
625+ devices[i] = platform->devices_ [i].get ();
613626 }
614627 }
615628
0 commit comments