|
11 | 11 | // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
|
12 | 12 |
|
13 | 13 | #include <executorch/backends/vulkan/runtime/utils/MacroUtils.h>
|
| 14 | +#include <executorch/backends/vulkan/runtime/utils/VecUtils.h> |
14 | 15 |
|
15 | 16 | #include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
|
16 | 17 | #include <executorch/backends/vulkan/runtime/vk_api/Command.h>
|
@@ -150,7 +151,7 @@ class Context final {
|
150 | 151 | void report_shader_dispatch_start(
|
151 | 152 | const std::string& shader_name,
|
152 | 153 | const utils::uvec3& global_wg_size,
|
153 |
| - const utils::uvec3& local_wg_size, |
| 154 | + const utils::WorkgroupSize& local_wg_size, |
154 | 155 | const uint32_t dispatch_id = UINT32_MAX);
|
155 | 156 |
|
156 | 157 | /*
|
@@ -189,13 +190,13 @@ class Context final {
|
189 | 190 |
|
190 | 191 | vkapi::DescriptorSet get_descriptor_set(
|
191 | 192 | const vkapi::ShaderInfo&,
|
192 |
| - const utils::uvec3&, |
| 193 | + const utils::WorkgroupSize&, |
193 | 194 | const vkapi::SpecVarList&,
|
194 | 195 | const uint32_t push_constants_size);
|
195 | 196 |
|
196 | 197 | inline vkapi::DescriptorSet get_descriptor_set(
|
197 | 198 | const vkapi::ShaderInfo& shader_descriptor,
|
198 |
| - const utils::uvec3& local_work_group_size) { |
| 199 | + const utils::WorkgroupSize& local_work_group_size) { |
199 | 200 | return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
|
200 | 201 | }
|
201 | 202 |
|
@@ -362,14 +363,17 @@ inline bool Context::submit_compute_job(
|
362 | 363 | report_shader_dispatch_start(
|
363 | 364 | shader.kernel_name,
|
364 | 365 | global_work_group,
|
365 |
| - local_work_group_size, |
| 366 | + utils::WorkgroupSize(local_work_group_size), |
366 | 367 | dispatch_id);
|
367 | 368 |
|
368 | 369 | // Factor out template parameter independent code to minimize code bloat.
|
369 | 370 | // Note that push constants are not exposed yet via this API, therefore the
|
370 | 371 | // push constants size is assumed to be 0.
|
371 | 372 | vkapi::DescriptorSet descriptor_set = get_descriptor_set(
|
372 |
| - shader, local_work_group_size, specialization_constants, 0u); |
| 373 | + shader, |
| 374 | + utils::WorkgroupSize(local_work_group_size), |
| 375 | + specialization_constants, |
| 376 | + 0u); |
373 | 377 |
|
374 | 378 | detail::bind(
|
375 | 379 | descriptor_set,
|
|
0 commit comments