Skip to content

Commit 973eb1c

Browse files
[ET-VK] Replacing the use of uvec3 with WorkgroupSize class to reduce memory usage and improve processing speed (#8671)
* [ET-VK] Adding a workgroup class to VecUtils

  Pull Request resolved: #8632

  This diff adds a new class called `WorkgroupSize` to the `VecUtils` header file. The `WorkgroupSize` class takes three `uint32_t` values as parameters and stores them in a single `uint32_t` variable using bitwise operations. This class is used in the Vulkan backend to specify the size of a workgroup for a given operation.

  ghstack-source-id: 268172661
  @exported-using-ghexport

  Differential Revision: [D70021019](https://our.internmc.facebook.com/intern/diff/D70021019/)

* [ET-VK] Adding reserve and append functions to SpecVarList

  Pull Request resolved: #8633

  This diff adds two new functions to the `SpecVarList` class in the Vulkan runtime library. The first function, `reserve`, allows the user to reserve a certain amount of space in the `SpecVarList` before adding any elements. The second function, `append`, allows the user to add a single `SpecVar` to the `SpecVarList`. These functions are useful for optimizing memory usage and improving performance in the Vulkan runtime.

  ghstack-source-id: 268172659
  @exported-using-ghexport

  Differential Revision: [D70021782](https://our.internmc.facebook.com/intern/diff/D70021782/)

* [ET-VK] Replacing the use of uvec3 with WorkgroupSize class to reduce memory usage and improve processing speed

  Pull Request resolved: #8634

  This diff replaces the use of `uvec3` with the `WorkgroupSize` class to reduce memory usage and improve processing speed in the Vulkan backend of ExecuTorch.

  ghstack-source-id: 268172660
  @exported-using-ghexport

  Differential Revision: [D70021032](https://our.internmc.facebook.com/intern/diff/D70021032/)

---------

Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent 54910bc commit 973eb1c

File tree

11 files changed

+42
-31
lines changed

11 files changed

+42
-31
lines changed

backends/vulkan/runtime/api/Context.cpp

+5-11
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ void Context::cmd_reset_querypool() {
7474
void Context::report_shader_dispatch_start(
7575
const std::string& shader_name,
7676
const utils::uvec3& global_wg_size,
77-
const utils::uvec3& local_wg_size,
77+
const utils::WorkgroupSize& local_wg_size,
7878
const uint32_t dispatch_id) {
7979
if (querypool_) {
8080
querypool_.shader_profile_begin(
8181
cmd_,
8282
dispatch_id,
8383
shader_name,
8484
vkapi::create_extent3d(global_wg_size),
85-
vkapi::create_extent3d(local_wg_size));
85+
vkapi::create_extent3d((utils::uvec3)local_wg_size));
8686
}
8787
}
8888

@@ -115,7 +115,7 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
115115

116116
vkapi::DescriptorSet Context::get_descriptor_set(
117117
const vkapi::ShaderInfo& shader_descriptor,
118-
const utils::uvec3& local_workgroup_size,
118+
const utils::WorkgroupSize& local_workgroup_size,
119119
const vkapi::SpecVarList& additional_constants,
120120
const uint32_t push_constants_size) {
121121
VkDescriptorSetLayout shader_layout =
@@ -124,17 +124,11 @@ vkapi::DescriptorSet Context::get_descriptor_set(
124124
VkPipelineLayout pipeline_layout =
125125
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
126126

127-
vkapi::SpecVarList spec_constants = {
128-
SV(local_workgroup_size[0u]),
129-
SV(local_workgroup_size[1u]),
130-
SV(local_workgroup_size[2u])};
131-
132-
spec_constants.append(additional_constants);
133-
134127
VkPipeline pipeline = pipeline_cache().retrieve(
135128
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
136129
shader_cache().retrieve(shader_descriptor),
137-
spec_constants});
130+
additional_constants,
131+
local_workgroup_size});
138132

139133
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
140134

backends/vulkan/runtime/api/Context.h

+9-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
1212

1313
#include <executorch/backends/vulkan/runtime/utils/MacroUtils.h>
14+
#include <executorch/backends/vulkan/runtime/utils/VecUtils.h>
1415

1516
#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
1617
#include <executorch/backends/vulkan/runtime/vk_api/Command.h>
@@ -150,7 +151,7 @@ class Context final {
150151
void report_shader_dispatch_start(
151152
const std::string& shader_name,
152153
const utils::uvec3& global_wg_size,
153-
const utils::uvec3& local_wg_size,
154+
const utils::WorkgroupSize& local_wg_size,
154155
const uint32_t dispatch_id = UINT32_MAX);
155156

156157
/*
@@ -189,13 +190,13 @@ class Context final {
189190

190191
vkapi::DescriptorSet get_descriptor_set(
191192
const vkapi::ShaderInfo&,
192-
const utils::uvec3&,
193+
const utils::WorkgroupSize&,
193194
const vkapi::SpecVarList&,
194195
const uint32_t push_constants_size);
195196

196197
inline vkapi::DescriptorSet get_descriptor_set(
197198
const vkapi::ShaderInfo& shader_descriptor,
198-
const utils::uvec3& local_work_group_size) {
199+
const utils::WorkgroupSize& local_work_group_size) {
199200
return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
200201
}
201202

@@ -362,14 +363,17 @@ inline bool Context::submit_compute_job(
362363
report_shader_dispatch_start(
363364
shader.kernel_name,
364365
global_work_group,
365-
local_work_group_size,
366+
utils::WorkgroupSize(local_work_group_size),
366367
dispatch_id);
367368

368369
// Factor out template parameter independent code to minimize code bloat.
369370
// Note that push constants are not exposed yet via this API, therefore the
370371
// push constants size is assumed to be 0.
371372
vkapi::DescriptorSet descriptor_set = get_descriptor_set(
372-
shader, local_work_group_size, specialization_constants, 0u);
373+
shader,
374+
utils::WorkgroupSize(local_work_group_size),
375+
specialization_constants,
376+
0u);
373377

374378
detail::bind(
375379
descriptor_set,

backends/vulkan/runtime/graph/ops/BlitNode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void BlitNode::encode(ComputeGraph* graph) {
4646
kernel_name += vkapi::to_string(dst_tensor->dtype());
4747

4848
context->report_shader_dispatch_start(
49-
kernel_name, utils::uvec3(), utils::uvec3(), node_id_);
49+
kernel_name, utils::uvec3(), utils::WorkgroupSize(), node_id_);
5050

5151
context->register_blit(
5252
pipeline_barrier,

backends/vulkan/runtime/graph/ops/DispatchNode.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class DispatchNode final : public ExecuteNode {
9292
protected:
9393
const vkapi::ShaderInfo shader_;
9494
const utils::uvec3 global_workgroup_size_;
95-
const utils::uvec3 local_workgroup_size_;
95+
const utils::WorkgroupSize local_workgroup_size_;
9696
const vkapi::ParamsBindList params_;
9797
const vkapi::SpecVarList spec_vars_;
9898
const std::vector<PushConstantDataInfo> push_constants_;

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
100100
// bound with the correct image layout.
101101
{
102102
vkapi::PipelineBarrier pipeline_barrier{};
103-
vkapi::DescriptorSet descriptor_set =
104-
context->get_descriptor_set(noop_shader_, {1, 1, 1});
103+
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
104+
noop_shader_, utils::WorkgroupSize(1, 1, 1));
105105

106106
bind_tensor_to_descriptor_set(
107107
*packed,

backends/vulkan/runtime/graph/ops/PrepackNode.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class PrepackNode final {
4949
const vkapi::ShaderInfo shader_;
5050
vkapi::ShaderInfo noop_shader_;
5151
const utils::uvec3 global_workgroup_size_;
52-
const utils::uvec3 local_workgroup_size_;
52+
const utils::WorkgroupSize local_workgroup_size_;
5353
const ValueRef tref_;
5454
const ValueRef packed_;
5555
const vkapi::ParamsBindList params_;

backends/vulkan/runtime/vk_api/Command.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void CommandBuffer::end() {
8181
void CommandBuffer::bind_pipeline(
8282
VkPipeline pipeline,
8383
VkPipelineLayout pipeline_layout,
84-
const utils::uvec3 local_workgroup_size) {
84+
const utils::WorkgroupSize local_workgroup_size) {
8585
VK_CHECK_COND(
8686
state_ == CommandBuffer::State::RECORDING,
8787
"Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state "

backends/vulkan/runtime/vk_api/Command.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class CommandBuffer final {
5151
struct Bound {
5252
VkPipeline pipeline;
5353
VkPipelineLayout pipeline_layout;
54-
utils::uvec3 local_workgroup_size;
54+
utils::WorkgroupSize local_workgroup_size;
5555
VkDescriptorSet descriptors;
5656

5757
explicit Bound()
@@ -63,7 +63,7 @@ class CommandBuffer final {
6363
inline void reset() {
6464
pipeline = VK_NULL_HANDLE;
6565
pipeline_layout = VK_NULL_HANDLE;
66-
local_workgroup_size = {0u, 0u, 0u};
66+
local_workgroup_size = utils::WorkgroupSize{0u, 0u, 0u};
6767
descriptors = VK_NULL_HANDLE;
6868
}
6969
};
@@ -87,7 +87,7 @@ class CommandBuffer final {
8787
void begin();
8888
void end();
8989

90-
void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
90+
void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::WorkgroupSize);
9191
void bind_descriptors(VkDescriptorSet);
9292
void set_push_constants(VkPipelineLayout, const void*, uint32_t);
9393

backends/vulkan/runtime/vk_api/Pipeline.cpp

+14-5
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,23 @@ ComputePipeline::ComputePipeline(
275275
const ComputePipeline::Descriptor& descriptor,
276276
VkPipelineCache pipeline_cache)
277277
: device_(device), handle_{VK_NULL_HANDLE} {
278-
std::vector<VkSpecializationMapEntry> map_entries =
279-
descriptor.specialization_constants.generate_map_entries();
278+
SpecVarList specialization_constants;
279+
280+
specialization_constants.reserve(
281+
3 + descriptor.specialization_constants.size());
282+
specialization_constants.append(descriptor.local_wg_size[0]);
283+
specialization_constants.append(descriptor.local_wg_size[1]);
284+
specialization_constants.append(descriptor.local_wg_size[2]);
285+
286+
specialization_constants.append(descriptor.specialization_constants);
287+
const std::vector<VkSpecializationMapEntry> map_entries =
288+
specialization_constants.generate_map_entries();
280289

281290
const VkSpecializationInfo specialization_info{
282-
descriptor.specialization_constants.size(), // mapEntryCount
291+
specialization_constants.size(), // mapEntryCount
283292
map_entries.data(), // pMapEntries
284-
descriptor.specialization_constants.data_nbytes(), // dataSize
285-
descriptor.specialization_constants.data(), // pData
293+
specialization_constants.data_nbytes(), // dataSize
294+
specialization_constants.data(), // pData
286295
};
287296

288297
const VkPipelineShaderStageCreateInfo shader_stage_create_info{

backends/vulkan/runtime/vk_api/Pipeline.h

+4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class ComputePipeline final {
156156
VkPipelineLayout pipeline_layout;
157157
VkShaderModule shader_module;
158158
SpecVarList specialization_constants;
159+
utils::WorkgroupSize local_wg_size;
159160
};
160161

161162
explicit ComputePipeline(
@@ -273,6 +274,9 @@ class ComputePipelineCache final {
273274
seed = utils::hash_combine(seed, new_seed);
274275
}
275276

277+
seed = utils::hash_combine(
278+
seed, std::hash<uint32_t>()((uint32_t)descriptor.local_wg_size));
279+
276280
return seed;
277281
}
278282
};

backends/vulkan/runtime/vk_api/Shader.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct ShaderInfo final {
6161
ShaderLayout::Signature kernel_layout{};
6262

6363
// Shader Metadata
64-
utils::uvec3 out_tile_size{1u, 1u, 1u};
64+
utils::WorkgroupSize out_tile_size{1u, 1u, 1u};
6565
bool requires_shader_int16 = false;
6666
bool requires_16bit_storage = false;
6767
bool requires_8bit_storage = false;

0 commit comments

Comments
 (0)