Skip to content

Commit 862f755

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Merge ArithmeticNode into ExecuteNode (#2247)
Summary: bypass-github-export-checks Pull Request resolved: #2247 This diff moves the logic of `ArithmeticNode` into its corresponding OpFunction `add_arithmetic_node()` and the `ExecuteNode` class. Our aim is to remove all derived classes of `ExecuteNode`, i.e., to make `ExecuteNode` a final class. All operator-specific logic will be handled in the OpFunction. Note the next change will move `StagingNode` into its OpFunction + this new ExecuteNode implementation. Until then, we can't tidy up the `ExecuteNode` class fully. Finally, we leave a few task TODOs. ghstack-source-id: 217439330 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D53982441 fbshipit-source-id: b8a51eee538b679e4168864a4870f3921c9ba333
1 parent fae9ef0 commit 862f755

File tree

9 files changed

+212
-77
lines changed

9 files changed

+212
-77
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
void ExecuteNode::encode(ComputeGraph* graph) {
20+
api::Context* const context = graph->context();
21+
api::PipelineBarrier pipeline_barrier{};
22+
23+
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
24+
25+
api::DescriptorSet descriptor_set =
26+
context->get_descriptor_set(shader_, local_workgroup_size_);
27+
28+
uint32_t idx = 0;
29+
idx = bind_values_to_descriptor_set(
30+
graph,
31+
outputs_,
32+
pipeline_barrier,
33+
api::MemoryAccessType::WRITE,
34+
descriptor_set,
35+
idx);
36+
idx = bind_values_to_descriptor_set(
37+
graph,
38+
inputs_,
39+
pipeline_barrier,
40+
api::MemoryAccessType::READ,
41+
descriptor_set,
42+
idx);
43+
descriptor_set.bind(idx, params_.buffer());
44+
45+
context->register_shader_dispatch(
46+
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
47+
}
48+
49+
} // namespace vulkan
50+
} // namespace native
51+
} // namespace at

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,37 @@ class ExecuteNode {
3333

3434
public:
3535
ExecuteNode(ValueRef input, ValueRef output)
36-
: inputs_{input}, outputs_{output} {}
36+
: outputs_{output}, inputs_{input} {}
37+
3738
ExecuteNode(
39+
const api::ShaderInfo& shader,
40+
const api::utils::uvec3& global_workgroup_size,
41+
const api::utils::uvec3& local_workgroup_size,
42+
const std::vector<ValueRef>& outputs,
3843
const std::vector<ValueRef>& inputs,
39-
const std::vector<ValueRef>& outputs)
40-
: inputs_(inputs), outputs_(outputs) {}
44+
api::UniformParamsBuffer&& params)
45+
: shader_(shader),
46+
global_workgroup_size_(global_workgroup_size),
47+
local_workgroup_size_(local_workgroup_size),
48+
outputs_(outputs),
49+
inputs_(inputs),
50+
params_(std::move(params)) {}
4151

4252
virtual ~ExecuteNode() = default;
4353

4454
protected:
45-
std::vector<ValueRef> inputs_;
55+
// TODO: Consider making members const after we remove StagingNode.
56+
api::ShaderInfo shader_;
57+
api::utils::uvec3 global_workgroup_size_;
58+
api::utils::uvec3 local_workgroup_size_;
4659
std::vector<ValueRef> outputs_;
60+
std::vector<ValueRef> inputs_;
61+
// TODO(T180906086): pass multiple buffers and index with ValueRef.
62+
// TODO(T180906457): allow re-computing param buffers.
63+
api::UniformParamsBuffer params_;
4764

4865
public:
49-
virtual void encode(ComputeGraph* graph) const = 0;
66+
virtual void encode(ComputeGraph* graph);
5067
};
5168

5269
} // namespace vulkan
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
10+
11+
namespace at {
12+
namespace native {
13+
namespace vulkan {
14+
15+
api::utils::ivec4 get_size_as_ivec4(const vTensor& t) {
16+
return api::utils::make_ivec4(
17+
{dim_at<Dim4D::Width>(t),
18+
dim_at<Dim4D::Height>(t),
19+
dim_at<Dim4D::Channel>(t),
20+
dim_at<Dim4D::Batch>(t)});
21+
}
22+
23+
void bind_tensor_to_descriptor_set(
24+
vTensor& tensor,
25+
api::PipelineBarrier& pipeline_barrier,
26+
const api::MemoryAccessType accessType,
27+
api::DescriptorSet& descriptor_set,
28+
const uint32_t idx) {
29+
if (tensor.buffer()) {
30+
api::VulkanBuffer& buffer = tensor.buffer(
31+
pipeline_barrier, api::PipelineStage::COMPUTE, accessType);
32+
descriptor_set.bind(idx, buffer);
33+
} else {
34+
api::VulkanImage& image =
35+
tensor.image(pipeline_barrier, api::PipelineStage::COMPUTE, accessType);
36+
descriptor_set.bind(idx, image);
37+
}
38+
}
39+
40+
uint32_t bind_values_to_descriptor_set(
41+
ComputeGraph* graph,
42+
const std::vector<ValueRef>& args,
43+
api::PipelineBarrier& pipeline_barrier,
44+
const api::MemoryAccessType accessType,
45+
api::DescriptorSet& descriptor_set,
46+
const uint32_t base_idx) {
47+
uint32_t idx = base_idx;
48+
for (auto& arg : args) {
49+
Value& val = graph->get_val(arg);
50+
if (val.isTensor()) {
51+
vTensor& tensor = val.toTensor();
52+
bind_tensor_to_descriptor_set(
53+
tensor, pipeline_barrier, accessType, descriptor_set, idx++);
54+
} else {
55+
VK_THROW("Unsupported type: ", val.type());
56+
}
57+
}
58+
return idx;
59+
}
60+
61+
} // namespace vulkan
62+
} // namespace native
63+
} // namespace at

backends/vulkan/runtime/graph/ops/Utils.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#ifdef USE_VULKAN_API
1212

13-
#include <ATen/native/vulkan/impl/Arithmetic.h>
13+
#include <ATen/native/vulkan/impl/Common.h>
1414

1515
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1616

@@ -21,6 +21,23 @@ namespace vulkan {
2121
#define DECLARE_OP_FN(function) \
2222
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args);
2323

24+
api::utils::ivec4 get_size_as_ivec4(const vTensor& t);
25+
26+
void bind_tensor_to_descriptor_set(
27+
vTensor& tensor,
28+
api::PipelineBarrier& pipeline_barrier,
29+
const api::MemoryAccessType accessType,
30+
api::DescriptorSet& descriptor_set,
31+
const uint32_t idx);
32+
33+
uint32_t bind_values_to_descriptor_set(
34+
ComputeGraph* graph,
35+
const std::vector<ValueRef>& args,
36+
api::PipelineBarrier& pipeline_barrier,
37+
const api::MemoryAccessType accessType,
38+
api::DescriptorSet& descriptor_set,
39+
const uint32_t base_idx);
40+
2441
} // namespace vulkan
2542
} // namespace native
2643
} // namespace at

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,39 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
1010

11-
#include <ATen/native/vulkan/impl/Common.h>
12-
1311
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1412

1513
namespace at {
1614
namespace native {
1715
namespace vulkan {
1816

19-
#define DEFINE_ARITHMETIC_FN(function, op_type) \
17+
#define DEFINE_ARITHMETIC_FN(function, shader) \
2018
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
2119
return add_arithmetic_node( \
22-
graph, \
23-
args[0], \
24-
args[1], \
25-
args[2], \
26-
arithmetic::OpType::op_type, \
27-
args[3]); \
20+
graph, args[0], args[1], args[2], VK_KERNEL(shader), args[3]); \
2821
}
2922

30-
DEFINE_ARITHMETIC_FN(add, ADD);
31-
DEFINE_ARITHMETIC_FN(sub, SUB);
32-
DEFINE_ARITHMETIC_FN(mul, MUL);
33-
DEFINE_ARITHMETIC_FN(div, DIV);
34-
DEFINE_ARITHMETIC_FN(floor_div, FLOOR_DIV);
35-
DEFINE_ARITHMETIC_FN(pow, POW);
23+
DEFINE_ARITHMETIC_FN(add, add);
24+
DEFINE_ARITHMETIC_FN(sub, sub);
25+
DEFINE_ARITHMETIC_FN(mul, mul);
26+
DEFINE_ARITHMETIC_FN(div, div);
27+
DEFINE_ARITHMETIC_FN(floor_div, floor_divide);
28+
DEFINE_ARITHMETIC_FN(pow, pow);
3629

30+
// TODO(T180908843): Bypass this entrypoint function by creating `ValueRef out`
31+
// ahead of time.
3732
ValueRef add_arithmetic_node(
3833
ComputeGraph& graph,
3934
const ValueRef in1,
4035
const ValueRef in2,
4136
const float alpha,
42-
const arithmetic::OpType optype,
37+
const api::ShaderInfo& shader,
4338
const int64_t shared_object_idx) {
4439
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
4540
api::ScalarType in1_dtype = graph.get_val_dtype(in1);
4641

4742
ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
48-
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
43+
add_arithmetic_node(graph, in1, in2, out, alpha, shader);
4944
return out;
5045
}
5146

@@ -67,12 +62,27 @@ void add_arithmetic_node(
6762
const ValueRef in2,
6863
const ValueRef out,
6964
const float alpha,
70-
const arithmetic::OpType optype) {
65+
const api::ShaderInfo& shader) {
7166
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
7267
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);
7368

74-
graph.execute_nodes().emplace_back(
75-
new ArithmeticNode(arg1, arg2, out, alpha, optype));
69+
vTensor& t_in1 = graph.get_val(arg1).toTensor();
70+
vTensor& t_in2 = graph.get_val(arg2).toTensor();
71+
vTensor& t_out = graph.get_val(out).toTensor();
72+
73+
api::utils::uvec3 global_size = t_out.extents();
74+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
75+
76+
ArithmeticParams block{
77+
get_size_as_ivec4(t_out),
78+
get_size_as_ivec4(t_in1),
79+
get_size_as_ivec4(t_in2),
80+
1.0,
81+
};
82+
api::UniformParamsBuffer params(graph.context(), block);
83+
84+
graph.execute_nodes().emplace_back(new ExecuteNode(
85+
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
7686
}
7787

7888
ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed)
@@ -92,23 +102,6 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
92102
encode_copy_to_vtensor(graph->context(), staging, packed);
93103
}
94104

95-
ArithmeticNode::ArithmeticNode(
96-
const ValueRef in1,
97-
const ValueRef in2,
98-
const ValueRef out,
99-
const float alpha,
100-
const arithmetic::OpType optype)
101-
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}
102-
103-
void ArithmeticNode::encode(ComputeGraph* graph) const {
104-
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
105-
vTensor& in2 = graph->get_val(inputs_[1]).toTensor();
106-
vTensor& out = graph->get_val(outputs_[0]).toTensor();
107-
108-
api::ShaderInfo kernel = arithmetic::get_shader(optype_);
109-
arithmetic::record_op(graph->context(), kernel, in1, in2, out, alpha_);
110-
}
111-
112105
} // namespace vulkan
113106
} // namespace native
114107
} // namespace at

backends/vulkan/runtime/graph/ops/impl/Arithmetic.h

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ValueRef add_arithmetic_node(
3232
const ValueRef in1,
3333
const ValueRef in2,
3434
const float alpha,
35-
const arithmetic::OpType optype,
35+
const api::ShaderInfo& shader,
3636
const int64_t shared_object_idx = -1);
3737

3838
void add_arithmetic_node(
@@ -41,29 +41,20 @@ void add_arithmetic_node(
4141
const ValueRef in2,
4242
const ValueRef out,
4343
const float alpha,
44-
const arithmetic::OpType optype);
44+
const api::ShaderInfo& shader);
4545

46-
class ArithmeticPrepack : public virtual PrepackNode {
47-
public:
48-
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);
49-
50-
void encode(ComputeGraph* graph) const override;
46+
struct ArithmeticParams final {
47+
api::utils::ivec4 outputSizes;
48+
api::utils::ivec4 input1Sizes;
49+
api::utils::ivec4 input2Sizes;
50+
float alpha;
5151
};
5252

53-
class ArithmeticNode : public virtual ExecuteNode {
53+
class ArithmeticPrepack : public virtual PrepackNode {
5454
public:
55-
explicit ArithmeticNode(
56-
const ValueRef in1,
57-
const ValueRef in2,
58-
const ValueRef out,
59-
const float alpha,
60-
const arithmetic::OpType optype);
55+
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);
6156

6257
void encode(ComputeGraph* graph) const override;
63-
64-
private:
65-
float alpha_;
66-
arithmetic::OpType optype_;
6758
};
6859

6960
} // namespace vulkan

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void encode_copy_from_vtensor(
100100

101101
StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {}
102102

103-
void StagingNode::encode(ComputeGraph* graph) const {
103+
void StagingNode::encode(ComputeGraph* graph) {
104104
Value& in_val = graph->get_val(inputs_[0]);
105105
Value& out_val = graph->get_val(outputs_[0]);
106106

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class StagingNode : public virtual ExecuteNode {
8888
public:
8989
explicit StagingNode(ValueRef from, ValueRef to);
9090

91-
void encode(ComputeGraph* graph) const override;
91+
void encode(ComputeGraph* graph) override;
9292
};
9393

9494
} // namespace vulkan

0 commit comments

Comments
 (0)