@@ -202,29 +202,29 @@ static bool isDeviceInvalidForBe(const device &Device) {
202202 if (Device.is_host ())
203203 return false ;
204204
205- // Taking the version information from the platform gives us more useful
206- // information than the driver_version of the device.
205+ // Retrieve Platform version to identify CUDA OpenCL platform
206+ // String: OpenCL 1.2 CUDA <version>
207207 const platform platform = Device.get_info <info::device::platform>();
208208 const std::string platformVersion =
209209 platform.get_info <info::platform::version>();
210+ const bool HasOpenCL = (platformVersion.find (" OpenCL" ) != std::string::npos);
211+ const bool HasCUDA = (platformVersion.find (" CUDA" ) != std::string::npos);
210212
211- backend *BackendPref = detail::SYCLConfig<detail::SYCL_BE>::get ();
212- auto BackendType = detail::getSyclObjImpl (Device)->getPlugin ().getBackend ();
213- static_assert (std::is_same<backend, decltype (BackendType)>(),
214- " Type is not the same" );
213+ backend *PrefBackend = detail::SYCLConfig<detail::SYCL_BE>::get ();
214+ auto DeviceBackend = detail::getSyclObjImpl (Device)->getPlugin ().getBackend ();
215215
216- // If no preference, assume OpenCL and reject CUDA backend
217- if (BackendType == backend::cuda && !BackendPref) {
216+ // Reject the NVIDIA OpenCL implementation
217+ if (DeviceBackend == backend::opencl && HasCUDA && HasOpenCL)
218218 return true ;
219- } else if (!BackendPref)
220- return false ;
221219
222- // If using PI_CUDA, don't accept a non-CUDA device
223- if (BackendType == backend::opencl && *BackendPref == backend::cuda)
220+ // If no preference, assume OpenCL and reject CUDA
221+ if (DeviceBackend == backend::cuda && !PrefBackend) {
224222 return true ;
223+ } else if (!PrefBackend)
224+ return false ;
225225
226- // If using PI_OPENCL, don't accept a non-OpenCL device
227- if (BackendType == backend::cuda && *BackendPref == backend::opencl)
226+ // If using PI_OPENCL, reject the CUDA backend
227+ if (DeviceBackend == backend::cuda && *PrefBackend == backend::opencl)
228228 return true ;
229229
230230 return false ;
0 commit comments