Skip to content

Commit

Permalink
[RUNTIME][VULKAN] Support total_global_memory (#16890)
Browse files Browse the repository at this point in the history
This PR supports total_global_memory query for vulkan devices.
  • Loading branch information
tqchen authored Apr 16, 2024
1 parent d1ac73c commit 3680a0d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_

for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
// match copy requirment
Expand All @@ -312,7 +312,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
// match copy requirment
Expand All @@ -324,8 +324,10 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
if (rank > win_rank) {
win_rank = rank;
compute_mtype_index = k;
compute_memory_size = heap_size;
}
}

ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";

if (device_properties.supports_push_descriptor) {
Expand Down Expand Up @@ -383,6 +385,7 @@ void VulkanDevice::do_swap(VulkanDevice&& other) {
std::swap(queue_insert_debug_utils_label_functions,
other.queue_insert_debug_utils_label_functions);
std::swap(compute_mtype_index, other.compute_mtype_index);
std::swap(compute_memory_size, other.compute_memory_size);
std::swap(queue, other.queue);
std::swap(queue_family_index, other.queue_family_index);
std::swap(physical_device_, other.physical_device_);
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class VulkanDevice {
queue_insert_debug_utils_label_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// maximum memory size for compute
int64_t compute_memory_size{0};

// queue family_index;
uint32_t queue_family_index{uint32_t(-1)};
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;

case kTotalGlobalMemory: {
*rv = device(index).compute_memory_size;
return;
}
}
Expand Down

0 comments on commit 3680a0d

Please sign in to comment.