Skip to content

Commit d50f26f

Browse files
authored
[ET-VK] Introduce DynamicDispatchNode (#11000)
## Context The `DynamicDispatchNode` class in introduced in this diff to allow for shader re-selection upon input resize. See the previous diff in the stack for more context on why this functionality is needed. Differential Revision: [D75013780](https://our.internmc.facebook.com/intern/diff/D75013780/)
1 parent 0ece07d commit d50f26f

7 files changed

+312
-4
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
2222

2323
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
24+
#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
2425
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
2526
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
2627

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class ComputeGraph;
2222
/*
2323
* Represents a single shader execution op in a ML model.
2424
*/
25-
class DispatchNode final : public ExecuteNode {
25+
class DispatchNode : public ExecuteNode {
2626
friend class ComputeGraph;
2727

2828
public:
@@ -43,9 +43,9 @@ class DispatchNode final : public ExecuteNode {
4343
void encode(ComputeGraph* graph) override;
4444

4545
protected:
46-
const vkapi::ShaderInfo shader_;
47-
const utils::uvec3 global_workgroup_size_;
48-
const utils::WorkgroupSize local_workgroup_size_;
46+
vkapi::ShaderInfo shader_;
47+
utils::uvec3 global_workgroup_size_;
48+
utils::WorkgroupSize local_workgroup_size_;
4949
const vkapi::ParamsBindList params_;
5050
const vkapi::SpecVarList spec_vars_;
5151
const std::vector<PushConstantDataInfo> push_constants_;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
namespace vkcompute {
14+
15+
DynamicDispatchNode::DynamicDispatchNode(
16+
ComputeGraph& graph,
17+
const PickShaderFn& pick_shader_fn,
18+
const PickGlobalFn& pick_global_wg_fn,
19+
const PickLocalFn& pick_local_wg_fn,
20+
const std::vector<ArgGroup>& args,
21+
const vkapi::ParamsBindList& params,
22+
const std::vector<PushConstantDataInfo>& push_constants,
23+
const vkapi::SpecVarList& spec_vars,
24+
const std::vector<ValueRef>& resize_args,
25+
const ResizeFunction& resize_fn)
26+
: DispatchNode(
27+
graph,
28+
pick_shader_fn(&graph, args, resize_args),
29+
pick_global_wg_fn(&graph, args, resize_args),
30+
pick_local_wg_fn(&graph, args, resize_args),
31+
args,
32+
params,
33+
push_constants,
34+
spec_vars,
35+
resize_args,
36+
resize_fn),
37+
pick_shader_fn_(pick_shader_fn),
38+
pick_global_wg_fn_(pick_global_wg_fn),
39+
pick_local_wg_fn_(pick_local_wg_fn) {}
40+
41+
void DynamicDispatchNode::encode(ComputeGraph* graph) {
42+
shader_ = pick_shader_fn_(graph, args_, resize_args_);
43+
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44+
local_workgroup_size_ =
45+
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
46+
DispatchNode::encode(graph);
47+
}
48+
49+
} // namespace vkcompute
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
14+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
17+
18+
namespace vkcompute {
19+
20+
class ComputeGraph;
21+
22+
/*
23+
* Represents a single shader execution op in a ML model.
24+
*/
25+
class DynamicDispatchNode final : public DispatchNode {
26+
friend class ComputeGraph;
27+
28+
public:
29+
using PickShaderFn = const std::function<vkapi::ShaderInfo(
30+
ComputeGraph*,
31+
const std::vector<ArgGroup>&,
32+
const std::vector<ValueRef>&)>;
33+
using PickGlobalFn = const std::function<utils::uvec3(
34+
ComputeGraph*,
35+
const std::vector<ArgGroup>&,
36+
const std::vector<ValueRef>&)>;
37+
using PickLocalFn = const std::function<utils::uvec3(
38+
ComputeGraph*,
39+
const std::vector<ArgGroup>&,
40+
const std::vector<ValueRef>&)>;
41+
42+
explicit DynamicDispatchNode(
43+
ComputeGraph& graph,
44+
const PickShaderFn& pick_shader_fn,
45+
const PickGlobalFn& pick_global_wg_fn,
46+
const PickLocalFn& pick_local_wg_fn,
47+
const std::vector<ArgGroup>& args,
48+
const vkapi::ParamsBindList& params,
49+
const std::vector<PushConstantDataInfo>& push_constants,
50+
const vkapi::SpecVarList& spec_vars,
51+
const std::vector<ValueRef>& resize_args,
52+
const ResizeFunction& resize_fn = nullptr);
53+
54+
~DynamicDispatchNode() override = default;
55+
56+
void encode(ComputeGraph* graph) override;
57+
58+
protected:
59+
const PickShaderFn pick_shader_fn_;
60+
const PickGlobalFn pick_global_wg_fn_;
61+
const PickLocalFn pick_local_wg_fn_;
62+
63+
public:
64+
operator bool() const {
65+
return shader_;
66+
}
67+
};
68+
69+
} // namespace vkcompute
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
${layout_declare_tensor(0, "w", "t_out", "float", "texture3d")}
16+
${layout_declare_tensor(1, "r", "t_in1", "float", "texture3d")}
17+
${layout_declare_tensor(2, "r", "t_in2", "float", "texture3d")}
18+
19+
layout(push_constant) uniform restrict Block {
20+
ivec4 out_sizes;
21+
ivec4 in1_sizes;
22+
ivec4 in2_sizes;
23+
};
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
void main() {
28+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
29+
30+
if (any(greaterThanEqual(pos, out_sizes.xyz))) {
31+
return;
32+
}
33+
34+
35+
vec4 out_texel = vec4(0.0);
36+
for (int row = 0; row < in1_sizes.y; ++row) {
37+
ivec3 in_pos = ivec3(pos.x, row, pos.z);
38+
vec4 in1_texel = texelFetch(t_in1, in_pos, 0);
39+
vec4 in2_texel = texelFetch(t_in2, in_pos, 0);
40+
41+
out_texel += in1_texel * in2_texel;
42+
}
43+
44+
imageStore(t_out, pos, out_texel + ${OFFSET});
45+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
dynamic_dispatch_test:
2+
parameter_names_with_default_values:
3+
OFFSET: 2.25
4+
shader_variants:
5+
- NAME: dynamic_dispatch_test_var1
6+
- NAME: dynamic_dispatch_test_var2
7+
OFFSET: 5.5

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) {
32973297
test_to_copy();
32983298
}
32993299
}
3300+
3301+
vkapi::ShaderInfo pick_dynamic_dispatch_shader(
3302+
ComputeGraph* graph,
3303+
const std::vector<ArgGroup>& args,
3304+
const std::vector<ValueRef>& additional_args) {
3305+
const ValueRef mat1 = args[1].refs[0];
3306+
3307+
std::string kernel_name = "dynamic_dispatch_test";
3308+
if (graph->size_at<int32_t>(-2, mat1) == 1) {
3309+
kernel_name += "_var1";
3310+
} else {
3311+
kernel_name += "_var2";
3312+
}
3313+
return VK_KERNEL_FROM_STR(kernel_name);
3314+
}
3315+
3316+
utils::uvec3 pick_dynamic_dispatch_global_wg_size(
3317+
ComputeGraph* graph,
3318+
const std::vector<ArgGroup>& args,
3319+
const std::vector<ValueRef>& additional_args) {
3320+
const ValueRef out = args[0].refs[0];
3321+
3322+
return graph->logical_limits_of(out);
3323+
}
3324+
3325+
utils::uvec3 pick_dynamic_dispatch_local_wg_size(
3326+
ComputeGraph* graph,
3327+
const std::vector<ArgGroup>& args,
3328+
const std::vector<ValueRef>& additional_args) {
3329+
return {64, 1, 1};
3330+
}
3331+
3332+
void resize_dynamic_dispatch_node(
3333+
ComputeGraph* graph,
3334+
const std::vector<ArgGroup>& args,
3335+
const std::vector<ValueRef>& additional_args) {
3336+
const ValueRef out = args[0].refs[0];
3337+
const ValueRef mat1 = args[1].refs[0];
3338+
3339+
std::vector<int64_t> out_sizes = graph->sizes_of(mat1);
3340+
out_sizes.at(out_sizes.size() - 2) = 1;
3341+
3342+
graph->get_tensor(out)->virtual_resize(out_sizes);
3343+
}
3344+
3345+
void add_dynamic_dispatch_test_node(
3346+
ComputeGraph& graph,
3347+
const ValueRef mat1,
3348+
const ValueRef mat2,
3349+
const ValueRef out) {
3350+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
3351+
graph,
3352+
pick_dynamic_dispatch_shader,
3353+
pick_dynamic_dispatch_global_wg_size,
3354+
pick_dynamic_dispatch_local_wg_size,
3355+
// Inputs and Outputs
3356+
{{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
3357+
// Shader params buffers
3358+
{},
3359+
// Push Constants
3360+
{graph.sizes_pc_of(out),
3361+
graph.sizes_pc_of(mat1),
3362+
graph.sizes_pc_of(mat2)},
3363+
// Specialization constants
3364+
{},
3365+
// Resize Logic
3366+
{},
3367+
resize_dynamic_dispatch_node));
3368+
}
3369+
3370+
vkcompute::ComputeGraph build_dynamic_dispatch_test_graph(int M, int N) {
3371+
using namespace vkcompute;
3372+
GraphConfig config;
3373+
ComputeGraph graph(config);
3374+
3375+
vkapi::ScalarType dtype = vkapi::kFloat;
3376+
utils::StorageType in_out_stype = utils::kTexture3D;
3377+
utils::GPUMemoryLayout memory_layout = utils::kWidthPacked;
3378+
3379+
std::vector<int64_t> mat1_size = {M, N};
3380+
std::vector<int64_t> mat2_size = {M, N};
3381+
std::vector<int64_t> out_size = {1, N};
3382+
3383+
IOValueRef mat1 =
3384+
graph.add_input_tensor(mat1_size, dtype, in_out_stype, memory_layout);
3385+
IOValueRef mat2{};
3386+
3387+
mat2.value = graph.add_tensor(mat2_size, dtype, in_out_stype, memory_layout);
3388+
mat2.staging = graph.set_input_tensor(mat2.value);
3389+
3390+
IOValueRef out;
3391+
out.value = graph.add_tensor(out_size, dtype, in_out_stype, memory_layout);
3392+
3393+
add_dynamic_dispatch_test_node(graph, mat1, mat2, out);
3394+
3395+
out.staging = graph.set_output_tensor(out.value);
3396+
3397+
return graph;
3398+
}
3399+
3400+
void test_dynamic_dispatch(int M, int N) {
3401+
ComputeGraph graph = build_dynamic_dispatch_test_graph(M, N);
3402+
3403+
graph.prepare();
3404+
graph.encode_prepack();
3405+
graph.prepack();
3406+
graph.encode_execute();
3407+
3408+
for (int i = 1; i < 4; i++) {
3409+
float val_mat1 = i;
3410+
float val_mat2 = i + 1;
3411+
// 5.3 is a hardcoded offset in the compute shader
3412+
float val_out = M * (val_mat1 * val_mat2) + 5.5;
3413+
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
3414+
}
3415+
3416+
// Switch to GEMV mode
3417+
int new_N = N / 2;
3418+
std::vector<int64_t> new_mat1_size = {1, new_N};
3419+
std::vector<int64_t> new_mat2_size = {1, new_N};
3420+
graph.resize_input(0, new_mat1_size);
3421+
graph.resize_input(1, new_mat2_size);
3422+
graph.propagate_resize();
3423+
3424+
graph.encode_execute();
3425+
3426+
for (int i = 1; i < 4; i++) {
3427+
float val_mat1 = i;
3428+
float val_mat2 = i + 1;
3429+
float val_out = (val_mat1 * val_mat2) + 2.25;
3430+
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
3431+
}
3432+
}
3433+
3434+
TEST(VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) {
3435+
test_dynamic_dispatch(128, 128);
3436+
}

0 commit comments

Comments
 (0)