@@ -117,6 +117,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
117
117
}
118
118
// Make `ctx` the current context for the caller by storing it in the entry
// returned by VulkanThreadEntry::ThreadLocal().
// NOTE(review): presumably that entry is per-thread state - confirm against
// VulkanThreadEntry.
void SetDevice(TVMContext ctx) final {
  VulkanThreadEntry::ThreadLocal()->ctx = ctx;
}
119
119
void GetAttr (TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final ;
120
+ uint32_t FindComputeQueue (VkPhysicalDevice phy_dev);
120
121
void * AllocDataSpace (TVMContext ctx, size_t nbytes, size_t alignment,
121
122
DLDataType type_hint) final {
122
123
const auto & vctx = context (ctx.device_id );
@@ -490,33 +491,17 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
490
491
std::vector<VkPhysicalDevice> all_phy_devs (phy_dev_count);
491
492
VULKAN_CALL (vkEnumeratePhysicalDevices (instance_, &phy_dev_count, dmlc::BeginPtr (all_phy_devs)));
492
493
for (VkPhysicalDevice phy_dev : all_phy_devs) {
493
- uint32_t queue_prop_count = 0 ;
494
- vkGetPhysicalDeviceQueueFamilyProperties (phy_dev, &queue_prop_count, nullptr );
495
- std::vector<VkQueueFamilyProperties> queue_props (queue_prop_count);
496
- vkGetPhysicalDeviceQueueFamilyProperties (phy_dev, &queue_prop_count,
497
- dmlc::BeginPtr (queue_props));
498
- uint32_t queue_family_index = 0 ;
499
- std::vector<VkDeviceQueueCreateInfo> queue_create_info;
494
+ uint32_t queue_family_index = FindComputeQueue (phy_dev);
495
+ if (queue_family_index == -1U ) continue ;
500
496
float priority = 1 .0f ;
501
- for (uint32_t i = 0 ; i < queue_props.size (); i++) {
502
- // find queues that support compute
503
- if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags ) {
504
- VkDeviceQueueCreateInfo info;
505
- info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
506
- info.pNext = nullptr ;
507
- info.flags = 0 ;
508
- info.queueFamilyIndex = i;
509
- info.queueCount = 1 ;
510
- info.pQueuePriorities = &priority;
511
-
512
- queue_create_info.push_back (info);
513
- // only use the first available queue for now
514
- if (queue_create_info.size () == 0 ) {
515
- queue_family_index = i;
516
- }
517
- }
518
- }
519
- if (queue_create_info.size () == 0 ) continue ;
497
+
498
+ VkDeviceQueueCreateInfo queue_create_info;
499
+ queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
500
+ queue_create_info.pNext = nullptr ;
501
+ queue_create_info.flags = 0 ;
502
+ queue_create_info.queueFamilyIndex = queue_family_index;
503
+ queue_create_info.queueCount = 1 ;
504
+ queue_create_info.pQueuePriorities = &priority;
520
505
521
506
VulkanContext ctx;
522
507
// setup context
@@ -554,8 +539,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
554
539
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
555
540
device_create_info.pNext = nullptr ;
556
541
device_create_info.flags = 0 ;
557
- device_create_info.queueCreateInfoCount = static_cast < uint32_t >(queue_create_info. size ()) ;
558
- device_create_info.pQueueCreateInfos = queue_create_info. data () ;
542
+ device_create_info.queueCreateInfoCount = 1 ;
543
+ device_create_info.pQueueCreateInfos = & queue_create_info;
559
544
device_create_info.enabledLayerCount = 0 ;
560
545
device_create_info.ppEnabledLayerNames = nullptr ;
561
546
device_create_info.enabledExtensionCount = extensions.size ();
@@ -677,7 +662,34 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
677
662
<< " \' phy_dev_id=" << context_[i].phy_device
678
663
<< " use_immediate=" << context_[i].UseImmediate ();
679
664
}
680
- } // namespace vulkan
665
+ }
666
+
667
+ uint32_t VulkanDeviceAPI::FindComputeQueue (VkPhysicalDevice phy_dev) {
668
+ uint32_t queue_prop_count = 0 ;
669
+ vkGetPhysicalDeviceQueueFamilyProperties (phy_dev, &queue_prop_count, nullptr );
670
+ std::vector<VkQueueFamilyProperties> queue_props (queue_prop_count);
671
+ vkGetPhysicalDeviceQueueFamilyProperties (phy_dev, &queue_prop_count, dmlc::BeginPtr (queue_props));
672
+ // Prefer compute-only queues. On cerain devices supporting this (e.g. Mesa RADV), using
673
+ // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop).
674
+ auto compute_dedicated = std::find_if (queue_props.begin (), queue_props.end (), [](auto prop) {
675
+ return (VK_QUEUE_COMPUTE_BIT & prop.queueFlags ) != 0 &&
676
+ (VK_QUEUE_GRAPHICS_BIT & prop.queueFlags ) == 0 ;
677
+ });
678
+ if (compute_dedicated == queue_props.end ()) {
679
+ auto compute = std::find_if (queue_props.begin (), queue_props.end (), [](auto prop) {
680
+ return (VK_QUEUE_COMPUTE_BIT & prop.queueFlags ) != 0 ;
681
+ });
682
+ if (compute == queue_props.end ()) {
683
+ return -1 ;
684
+ } else {
685
+ return std::distance (queue_props.begin (), compute);
686
+ }
687
+ } else {
688
+ return std::distance (queue_props.begin (), compute_dedicated);
689
+ }
690
+ }
691
+
692
+ // namespace vulkan
681
693
class VulkanModuleNode ;
682
694
683
695
// a wrapped function class to get packed func.
0 commit comments