Skip to content

Commit d445a34

Browse files
committed
[ET-VK][Op Redesign][7/n] Generalize ExecuteNode args with ArgGroup
Leftover from Op Redesign 5/n - D54445787. --- Typically, we specify outputs first and inputs second in the shader layout, but not always. In `image_to_nchw.glsl`, this is flipped: https://www.internalfb.com/code/fbsource/[d303d229f22616bfba32e5bb5d4d27dc656f41a7]/fbcode/caffe2/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl?lines=8-19 Hence, we generalize our `ExecuteNode` specification to take a vector of args (image, buffer, etc.), with specification of access type. Since typically we will group args of the same access together, we correspond one access specification to multiple args. We reuse `api::MemoryAccessType` for access specification. Differential Revision: [D54518840](https://our.internmc.facebook.com/intern/diff/D54518840/) ghstack-source-id: 217489419 Pull Request resolved: #2262
1 parent 91c3d65 commit d445a34

File tree

6 files changed

+47
-36
lines changed

6 files changed

+47
-36
lines changed

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
2727

2828
uint32_t idx = 0;
2929
idx = bind_values_to_descriptor_set(
30-
graph,
31-
outputs_,
32-
pipeline_barrier,
33-
api::MemoryAccessType::WRITE,
34-
descriptor_set,
35-
idx);
36-
idx = bind_values_to_descriptor_set(
37-
graph,
38-
inputs_,
39-
pipeline_barrier,
40-
api::MemoryAccessType::READ,
41-
descriptor_set,
42-
idx);
30+
graph, args_, pipeline_barrier, descriptor_set, idx);
4331
descriptor_set.bind(idx, params_.buffer());
4432

4533
context->register_shader_dispatch(

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ namespace vulkan {
2222

2323
class ComputeGraph;
2424

25+
/*
26+
* Represents a group of shader arguments (images and/or buffers), with a common
27+
* access permission.
28+
*/
29+
struct ArgGroup {
30+
ArgGroup(const ValueRef ref, const api::MemoryAccessType access)
31+
: refs{ref}, access(access) {}
32+
33+
ArgGroup(
34+
const std::vector<ValueRef>& refs,
35+
const api::MemoryAccessType access)
36+
: refs(refs), access(access) {}
37+
38+
const std::vector<ValueRef> refs;
39+
const api::MemoryAccessType access;
40+
};
41+
2542
/*
2643
* Represents a single execution op in a ML model. In graph mode, ops will be
2744
* implemented in a derived class that implements encode, which will implement
@@ -36,14 +53,12 @@ class ExecuteNode final {
3653
const api::ShaderInfo& shader,
3754
const api::utils::uvec3& global_workgroup_size,
3855
const api::utils::uvec3& local_workgroup_size,
39-
const std::vector<ValueRef>& outputs,
40-
const std::vector<ValueRef>& inputs,
56+
const std::vector<ArgGroup>& args,
4157
api::UniformParamsBuffer&& params)
4258
: shader_(shader),
4359
global_workgroup_size_(global_workgroup_size),
4460
local_workgroup_size_(local_workgroup_size),
45-
outputs_(outputs),
46-
inputs_(inputs),
61+
args_(args),
4762
params_(std::move(params)) {}
4863

4964
~ExecuteNode() = default;
@@ -54,8 +69,7 @@ class ExecuteNode final {
5469
const api::ShaderInfo shader_;
5570
const api::utils::uvec3 global_workgroup_size_;
5671
const api::utils::uvec3 local_workgroup_size_;
57-
const std::vector<ValueRef> outputs_;
58-
const std::vector<ValueRef> inputs_;
72+
const std::vector<ArgGroup> args_;
5973
// TODO(T180906086): pass multiple buffers and index with ValueRef.
6074
// TODO(T180906457): allow re-computing param buffers.
6175
api::UniformParamsBuffer params_;

backends/vulkan/runtime/graph/ops/Utils.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,26 @@ void bind_staging_to_descriptor_set(
4646

4747
uint32_t bind_values_to_descriptor_set(
4848
ComputeGraph* graph,
49-
const std::vector<ValueRef>& args,
49+
const std::vector<ArgGroup>& args,
5050
api::PipelineBarrier& pipeline_barrier,
51-
const api::MemoryAccessType accessType,
5251
api::DescriptorSet& descriptor_set,
5352
const uint32_t base_idx) {
5453
uint32_t idx = base_idx;
5554
for (auto& arg : args) {
56-
Value& val = graph->get_val(arg);
57-
if (val.isTensor()) {
58-
bind_tensor_to_descriptor_set(
59-
val.toTensor(), pipeline_barrier, accessType, descriptor_set, idx++);
60-
} else if (val.isStaging()) {
61-
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
62-
} else {
63-
VK_THROW("Unsupported type: ", val.type());
55+
for (auto& ref : arg.refs) {
56+
Value& val = graph->get_val(ref);
57+
if (val.isTensor()) {
58+
bind_tensor_to_descriptor_set(
59+
val.toTensor(),
60+
pipeline_barrier,
61+
arg.access,
62+
descriptor_set,
63+
idx++);
64+
} else if (val.isStaging()) {
65+
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
66+
} else {
67+
VK_THROW("Unsupported type: ", val.type());
68+
}
6469
}
6570
}
6671
return idx;

backends/vulkan/runtime/graph/ops/Utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ void bind_staging_to_descriptor_set(
3737

3838
uint32_t bind_values_to_descriptor_set(
3939
ComputeGraph* graph,
40-
const std::vector<ValueRef>& args,
40+
const std::vector<ArgGroup>& args,
4141
api::PipelineBarrier& pipeline_barrier,
42-
const api::MemoryAccessType accessType,
4342
api::DescriptorSet& descriptor_set,
4443
const uint32_t base_idx);
4544

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@ void add_arithmetic_node(
7070
api::UniformParamsBuffer params(graph.context(), block);
7171

7272
graph.execute_nodes().emplace_back(new ExecuteNode(
73-
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
73+
shader,
74+
global_size,
75+
local_size,
76+
{{out, api::MemoryAccessType::WRITE},
77+
{{arg1, arg2}, api::MemoryAccessType::READ}},
78+
std::move(params)));
7479
}
7580

7681
} // namespace vulkan

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ void add_staging_to_tensor_node(
5252
shader,
5353
global_size,
5454
local_size,
55-
{out_tensor},
56-
{in_staging},
55+
{{out_tensor, api::MemoryAccessType::WRITE},
56+
{in_staging, api::MemoryAccessType::READ}},
5757
std::move(params)));
5858
}
5959

@@ -94,8 +94,8 @@ void add_tensor_to_staging_node(
9494
shader,
9595
global_size,
9696
local_size,
97-
{in_tensor},
98-
{out_staging},
97+
{{in_tensor, api::MemoryAccessType::READ},
98+
{out_staging, api::MemoryAccessType::WRITE}},
9999
std::move(params)));
100100
}
101101

0 commit comments

Comments
 (0)