|
| 1 | +use crate::{ |
| 2 | + feature_info::with_feature_info, |
| 3 | + nvsdk_ngx::{ |
| 4 | + DlssError, NVSDK_NGX_VULKAN_GetFeatureDeviceExtensionRequirements, |
| 5 | + NVSDK_NGX_VULKAN_GetFeatureInstanceExtensionRequirements, check_ngx_result, |
| 6 | + }, |
| 7 | +}; |
| 8 | +use ash::{Entry, vk::PhysicalDevice}; |
| 9 | +use std::{ffi::CStr, ptr, slice}; |
| 10 | +use uuid::Uuid; |
| 11 | +use wgpu::{ |
| 12 | + Adapter, Device, DeviceDescriptor, Instance, InstanceDescriptor, Queue, RequestDeviceError, |
| 13 | + hal::{DeviceError, InstanceError, api::Vulkan}, |
| 14 | +}; |
| 15 | + |
| 16 | +/// Creates a wgpu [`Instance`] with the extensions required for DLSS. |
| 17 | +/// |
| 18 | +/// If the system does not support DLSS, it will set `dlss_supported` to false. |
| 19 | +pub fn create_instance( |
| 20 | + project_id: Uuid, |
| 21 | + instance_descriptor: &InstanceDescriptor, |
| 22 | + dlss_supported: &mut bool, |
| 23 | +) -> Result<Instance, InitializationError> { |
| 24 | + unsafe { |
| 25 | + let mut result = Ok(()); |
| 26 | + let raw_instance = wgpu::hal::vulkan::Instance::init_with_callback( |
| 27 | + &wgpu::hal::InstanceDescriptor { |
| 28 | + name: "wgpu", |
| 29 | + flags: instance_descriptor.flags, |
| 30 | + memory_budget_thresholds: instance_descriptor.memory_budget_thresholds, |
| 31 | + backend_options: instance_descriptor.backend_options.clone(), |
| 32 | + }, |
| 33 | + Some(Box::new(|args| { |
| 34 | + match required_instance_extensions(project_id, args.entry) { |
| 35 | + Ok((extensions, true)) => args.extensions.extend(extensions), |
| 36 | + Ok((_, false)) => *dlss_supported = false, |
| 37 | + Err(err) => result = Err(err), |
| 38 | + } |
| 39 | + })), |
| 40 | + )?; |
| 41 | + result?; |
| 42 | + |
| 43 | + Ok(Instance::from_hal::<Vulkan>(raw_instance)) |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +/// Creates a wgpu [`Device`] and [`Queue`] with the extensions required for DLSS. |
| 48 | +/// |
| 49 | +/// If the system does not support DLSS, it will set `dlss_supported` to false. |
| 50 | +/// |
| 51 | +/// The provided [`Adapter`] must be using the Vulkan backend. |
| 52 | +pub fn request_device( |
| 53 | + project_id: Uuid, |
| 54 | + adapter: &Adapter, |
| 55 | + device_descriptor: &DeviceDescriptor, |
| 56 | + dlss_supported: &mut bool, |
| 57 | +) -> Result<(Device, Queue), InitializationError> { |
| 58 | + unsafe { |
| 59 | + let raw_adapter = adapter |
| 60 | + .as_hal::<Vulkan>() |
| 61 | + .ok_or(InitializationError::UnsupportedBackend)?; |
| 62 | + let raw_instance = raw_adapter.shared_instance().raw_instance(); |
| 63 | + let raw_physical_device = raw_adapter.raw_physical_device(); |
| 64 | + |
| 65 | + let mut result = Ok(()); |
| 66 | + let open_device = raw_adapter.open_with_callback( |
| 67 | + device_descriptor.required_features, |
| 68 | + &device_descriptor.memory_hints, |
| 69 | + Some(Box::new(|args| { |
| 70 | + match required_device_extensions( |
| 71 | + project_id, |
| 72 | + &raw_adapter, |
| 73 | + raw_instance.handle(), |
| 74 | + raw_physical_device, |
| 75 | + ) { |
| 76 | + Ok((extensions, true)) => args.extensions.extend(extensions), |
| 77 | + Ok((_, false)) => *dlss_supported = false, |
| 78 | + Err(err) => result = Err(err), |
| 79 | + } |
| 80 | + })), |
| 81 | + )?; |
| 82 | + result?; |
| 83 | + |
| 84 | + Ok(adapter.create_device_from_hal::<Vulkan>(open_device, device_descriptor)?) |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +fn required_instance_extensions( |
| 89 | + project_id: Uuid, |
| 90 | + entry: &Entry, |
| 91 | +) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> { |
| 92 | + with_feature_info(project_id, |feature_info| unsafe { |
| 93 | + // Get required extension names |
| 94 | + let mut required_extensions = ptr::null_mut(); |
| 95 | + let mut required_extension_count = 0; |
| 96 | + check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureInstanceExtensionRequirements( |
| 97 | + feature_info, |
| 98 | + &mut required_extension_count, |
| 99 | + &mut required_extensions, |
| 100 | + ))?; |
| 101 | + let required_extensions = |
| 102 | + slice::from_raw_parts(required_extensions, required_extension_count as usize); |
| 103 | + let required_extensions = required_extensions |
| 104 | + .iter() |
| 105 | + .map(|extension| CStr::from_ptr(extension.extension_name.as_ptr())); |
| 106 | + |
| 107 | + // Check that the required extensions are supported |
| 108 | + let supported_extensions = entry.enumerate_instance_extension_properties(None)?; |
| 109 | + let extensions_supported = required_extensions.clone().all(|required_extension| { |
| 110 | + supported_extensions |
| 111 | + .iter() |
| 112 | + .any(|extension| extension.extension_name_as_c_str() == Ok(required_extension)) |
| 113 | + }); |
| 114 | + |
| 115 | + Ok((required_extensions, extensions_supported)) |
| 116 | + }) |
| 117 | +} |
| 118 | + |
| 119 | +fn required_device_extensions( |
| 120 | + project_id: Uuid, |
| 121 | + raw_adapter: &wgpu::hal::vulkan::Adapter, |
| 122 | + raw_instance: ash::vk::Instance, |
| 123 | + raw_physical_device: PhysicalDevice, |
| 124 | +) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> { |
| 125 | + with_feature_info(project_id, |feature_info| unsafe { |
| 126 | + // Get required extension names |
| 127 | + let mut required_extensions = ptr::null_mut(); |
| 128 | + let mut required_extension_count = 0; |
| 129 | + check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureDeviceExtensionRequirements( |
| 130 | + raw_instance, |
| 131 | + raw_physical_device, |
| 132 | + feature_info, |
| 133 | + &mut required_extension_count, |
| 134 | + &mut required_extensions, |
| 135 | + ))?; |
| 136 | + let required_extensions = |
| 137 | + slice::from_raw_parts(required_extensions, required_extension_count as usize); |
| 138 | + let required_extensions = required_extensions |
| 139 | + .iter() |
| 140 | + .map(|extension| CStr::from_ptr(extension.extension_name.as_ptr())); |
| 141 | + |
| 142 | + // Check that the required extensions are supported |
| 143 | + let extensions_supported = required_extensions.clone().all(|required_extension| { |
| 144 | + raw_adapter |
| 145 | + .physical_device_capabilities() |
| 146 | + .supports_extension(required_extension) |
| 147 | + }); |
| 148 | + |
| 149 | + Ok((required_extensions, extensions_supported)) |
| 150 | + }) |
| 151 | +} |
| 152 | + |
| 153 | +/// Error returned by [`request_device`]. |
| 154 | +#[derive(thiserror::Error, Debug)] |
| 155 | +pub enum InitializationError { |
| 156 | + #[error(transparent)] |
| 157 | + InstanceError(#[from] InstanceError), |
| 158 | + #[error(transparent)] |
| 159 | + RequestDeviceError(#[from] RequestDeviceError), |
| 160 | + #[error(transparent)] |
| 161 | + DeviceError(#[from] DeviceError), |
| 162 | + #[error(transparent)] |
| 163 | + VulkanError(#[from] ash::vk::Result), |
| 164 | + #[error(transparent)] |
| 165 | + DlssError(#[from] DlssError), |
| 166 | + #[error("Provided adapter is not using the Vulkan backend")] |
| 167 | + UnsupportedBackend, |
| 168 | +} |
0 commit comments