Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
}
void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
std::vector<uint32_t> GetComputeQueueFamilies(VkPhysicalDevice phy_dev);
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) final {
const auto& vctx = context(ctx.device_id);
Expand Down Expand Up @@ -490,33 +491,20 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
for (VkPhysicalDevice phy_dev : all_phy_devs) {
uint32_t queue_prop_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
dmlc::BeginPtr(queue_props));
uint32_t queue_family_index = 0;
std::vector<VkDeviceQueueCreateInfo> queue_create_info;
// Get a list of queue families supporting compute, in order of preference. We currently
// only make use of the single most preferred family.
std::vector<uint32_t> queue_family_indexes = GetComputeQueueFamilies(phy_dev);
if (queue_family_indexes.empty()) continue;
uint32_t queue_family_index = queue_family_indexes[0];
float priority = 1.0f;
for (uint32_t i = 0; i < queue_props.size(); i++) {
// find queues that support compute
if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
VkDeviceQueueCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.queueFamilyIndex = i;
info.queueCount = 1;
info.pQueuePriorities = &priority;

queue_create_info.push_back(info);
// only use the first available queue for now
if (queue_create_info.size() == 0) {
queue_family_index = i;
}
}
}
if (queue_create_info.size() == 0) continue;

struct VkDeviceQueueCreateInfo queue_create_info;
queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
queue_create_info.pNext = nullptr;
queue_create_info.flags = 0;
queue_create_info.queueFamilyIndex = queue_family_index;
queue_create_info.queueCount = 1;
queue_create_info.pQueuePriorities = &priority;

VulkanContext ctx;
// setup context
Expand Down Expand Up @@ -554,8 +542,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
device_create_info.flags = 0;
device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
device_create_info.pQueueCreateInfos = queue_create_info.data();
device_create_info.queueCreateInfoCount = 1;
device_create_info.pQueueCreateInfos = &queue_create_info;
device_create_info.enabledLayerCount = 0;
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = extensions.size();
Expand Down Expand Up @@ -677,7 +665,34 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
<< "\' phy_dev_id=" << context_[i].phy_device
<< " use_immediate=" << context_[i].UseImmediate();
}
} // namespace vulkan
}

// Returns the indices of the queue families of `phy_dev` that advertise
// VK_QUEUE_COMPUTE_BIT, ordered by preference: families without
// VK_QUEUE_GRAPHICS_BIT come first, followed by families that have both bits.
// Within each group the indices keep their ascending enumeration order.
// The result is empty when no family supports compute.
std::vector<uint32_t> VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) {
  // Two-call query: the first call reports how many families exist, the second
  // fills in their properties.
  uint32_t family_count = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &family_count, nullptr);
  std::vector<VkQueueFamilyProperties> families(family_count);
  vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &family_count, dmlc::BeginPtr(families));

  // Compute-only queues are preferred. On certain devices supporting this
  // (e.g. Mesa RADV), using compute-only queues gives better responsiveness
  // for other graphics workloads (e.g. the desktop).
  std::vector<uint32_t> compute_only;
  // Families that also do graphics are kept as lower-priority fallbacks.
  std::vector<uint32_t> compute_and_graphics;

  for (uint32_t idx = 0; idx != family_count; ++idx) {
    const auto flags = families[idx].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) == 0) {
      continue;  // no compute support, not a candidate
    }
    if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0) {
      compute_only.push_back(idx);
    } else {
      compute_and_graphics.push_back(idx);
    }
  }

  // Append the fallbacks after the preferred families to form the final order.
  compute_only.insert(compute_only.end(), compute_and_graphics.begin(),
                      compute_and_graphics.end());
  return compute_only;
}

// namespace vulkan
class VulkanModuleNode;

// a wrapped function class to get packed func.
Expand Down