
Commit b2862ea

jorgep31415 authored and facebook-github-bot committed
Merge StagingNode into ExecuteNode (#2260)
Summary:
bypass-github-export-checks

Pull Request resolved: #2260

We dispose of `StagingNode` in favor of the functions `add_staging_to_tensor_node()` and `add_tensor_to_staging_node()`, each of which creates an `ExecuteNode`. Hence, we fulfill our goal of making `ExecuteNode` a final class. These `add_X_node()` functions are not registered as an `OpFunction`, since staging is not an operator; its purpose is specific to starting and ending Vulkan execution.

Note that we can't remove `encode_copy_to_vtensor()` yet, as it's still used in ArithmeticPrepack. The prepack refactor is next.

ghstack-source-id: 217439329
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D54445787

fbshipit-source-id: f455327630de2873be85d035f42efedda2810047
1 parent a5c1890 commit b2862ea
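In sketch form, the refactor swaps a node subclass for a free factory function. A minimal, self-contained illustration with toy stand-ins (Node and Graph here are not the real ExecuteNode/ComputeGraph):

#include <memory>
#include <vector>

// Toy stand-in for the now-final ExecuteNode: no subclasses, no virtuals.
class Node final {
 public:
  Node(int input, int output) : input_(input), output_(output) {}
  int encode() const { return input_ + output_; }  // pretend to record work

 private:
  const int input_;
  const int output_;
};

struct Graph {
  std::vector<std::unique_ptr<Node>> execute_nodes;
};

// Analogous to add_staging_to_tensor_node(): a free function owns the
// construction details that the StagingNode subclass used to own.
void add_copy_node(Graph& graph, int input, int output) {
  graph.execute_nodes.push_back(std::make_unique<Node>(input, output));
}

int main() {
  Graph g;
  add_copy_node(g, /*input=*/0, /*output=*/1);
  return g.execute_nodes.size() == 1 ? 0 : 1;
}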

File tree: 7 files changed (+237, −70)


backends/vulkan/runtime/graph/ComputeGraph.cpp
Lines changed: 2 additions & 2 deletions

@@ -81,7 +81,7 @@ ValueRef ComputeGraph::set_input_tensor(
   if (use_staging) {
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
-    execute_nodes_.emplace_back(new StagingNode(staging_idx, idx));
+    add_staging_to_tensor_node(*this, staging_idx, idx);
     inputs_.push_back(staging_idx);
     return staging_idx;
   }
@@ -95,7 +95,7 @@ ValueRef ComputeGraph::set_output_tensor(
   if (use_staging) {
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
-    execute_nodes_.emplace_back(new StagingNode(idx, staging_idx));
+    add_tensor_to_staging_node(*this, idx, staging_idx);
     outputs_.push_back(staging_idx);
     return staging_idx;
   }
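Both call sites follow the same shape: allocate a staging buffer sized for the tensor, then record a copy node at the graph boundary. A minimal, self-contained sketch of that round trip, using toy stand-in types (ints play the role of ValueRefs; this is not the real ComputeGraph API):

#include <cstdio>
#include <vector>

// A copy at the graph boundary: src is read, dst is written.
struct CopyNode {
  int src;
  int dst;
};

struct ToyGraph {
  std::vector<CopyNode> execute_nodes;
  std::vector<int> inputs;
  std::vector<int> outputs;
  int next_ref = 0;

  int add_tensor() { return next_ref++; }
  int add_staging() { return next_ref++; }

  // Mirrors set_input_tensor(): staging -> tensor copy at graph entry.
  int set_input_tensor(int tensor) {
    int staging = add_staging();
    execute_nodes.push_back({staging, tensor});
    inputs.push_back(staging);
    return staging;
  }

  // Mirrors set_output_tensor(): tensor -> staging copy at graph exit.
  int set_output_tensor(int tensor) {
    int staging = add_staging();
    execute_nodes.push_back({tensor, staging});
    outputs.push_back(staging);
    return staging;
  }
};

int main() {
  ToyGraph g;
  int t_in = g.add_tensor();
  int t_out = g.add_tensor();
  g.set_input_tensor(t_in);
  g.set_output_tensor(t_out);
  std::printf("boundary copies: %zu\n", g.execute_nodes.size());  // prints 2
  return 0;
}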

backends/vulkan/runtime/graph/ops/ExecuteNode.h
Lines changed: 9 additions & 14 deletions

@@ -28,13 +28,10 @@ class ComputeGraph;
  * encoding of the shader corresponding to the op into the command buffer of a
  * ComputeGraph.
  */
-class ExecuteNode {
+class ExecuteNode final {
   friend class ComputeGraph;

  public:
-  ExecuteNode(ValueRef input, ValueRef output)
-      : outputs_{output}, inputs_{input} {}
-
   ExecuteNode(
       const api::ShaderInfo& shader,
       const api::utils::uvec3& global_workgroup_size,
@@ -49,21 +46,19 @@ class ExecuteNode {
        inputs_(inputs),
        params_(std::move(params)) {}

-  virtual ~ExecuteNode() = default;
+  ~ExecuteNode() = default;
+
+  void encode(ComputeGraph* graph);

  protected:
-  // TODO: Consider making members const after we remove StagingNode.
-  api::ShaderInfo shader_;
-  api::utils::uvec3 global_workgroup_size_;
-  api::utils::uvec3 local_workgroup_size_;
-  std::vector<ValueRef> outputs_;
-  std::vector<ValueRef> inputs_;
+  const api::ShaderInfo shader_;
+  const api::utils::uvec3 global_workgroup_size_;
+  const api::utils::uvec3 local_workgroup_size_;
+  const std::vector<ValueRef> outputs_;
+  const std::vector<ValueRef> inputs_;
   // TODO(T180906086): pass multiple buffers and index with ValueRef.
   // TODO(T180906457): allow re-computing param buffers.
   api::UniformParamsBuffer params_;
-
- public:
-  virtual void encode(ComputeGraph* graph);
 };

 } // namespace vulkan
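With the two-argument constructor gone, nothing derives from ExecuteNode anymore, which is what lets this diff drop `virtual` and mark the members const. A toy, self-contained illustration of why `final` makes the non-virtual destructor safe (stand-in class, not the real ExecuteNode):

#include <cstddef>
#include <utility>
#include <vector>

// Marking the class `final` means no type can derive from it, so deletion
// through a base-class pointer cannot happen and the destructor no longer
// needs to be virtual; members can also become const, since no subclass
// needs to reassign them after construction.
class NodeLike final {
 public:
  explicit NodeLike(std::vector<int> inputs) : inputs_(std::move(inputs)) {}
  ~NodeLike() = default;  // non-virtual is now safe

  std::size_t encode() const { return inputs_.size(); }  // also non-virtual

 private:
  const std::vector<int> inputs_;
};

// Uncommenting the line below fails to compile, which is the point:
// class StagingLike : public NodeLike {};  // error: base 'NodeLike' is final

int main() {
  NodeLike node({0, 1, 2});
  return node.encode() == 3 ? 0 : 1;
}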

backends/vulkan/runtime/graph/ops/Utils.cpp
Lines changed: 10 additions & 2 deletions

@@ -37,6 +37,13 @@ void bind_tensor_to_descriptor_set(
   }
 }

+void bind_staging_to_descriptor_set(
+    api::StorageBuffer& staging,
+    api::DescriptorSet& descriptor_set,
+    const uint32_t idx) {
+  descriptor_set.bind(idx, staging.buffer());
+}
+
 uint32_t bind_values_to_descriptor_set(
     ComputeGraph* graph,
     const std::vector<ValueRef>& args,
@@ -48,9 +55,10 @@ uint32_t bind_values_to_descriptor_set(
   for (auto& arg : args) {
     Value& val = graph->get_val(arg);
     if (val.isTensor()) {
-      vTensor& tensor = val.toTensor();
       bind_tensor_to_descriptor_set(
-          tensor, pipeline_barrier, accessType, descriptor_set, idx++);
+          val.toTensor(), pipeline_barrier, accessType, descriptor_set, idx++);
+    } else if (val.isStaging()) {
+      bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
     } else {
       VK_THROW("Unsupported type: ", val.type());
     }
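The dispatch added here is a tagged-type walk: inspect each arg's runtime type and route it to the matching bind helper, failing loudly on anything else. A self-contained sketch of the same pattern, with std::variant standing in for the graph's Value type (toy types, C++17):

#include <cstdint>
#include <stdexcept>
#include <variant>
#include <vector>

// Toy stand-ins for vTensor and api::StorageBuffer.
struct Tensor {};
struct Staging {};
using Value = std::variant<Tensor, Staging>;

void bind_tensor(const Tensor&, uint32_t /*idx*/) { /* image binding */ }
void bind_staging(const Staging&, uint32_t /*idx*/) { /* buffer binding */ }

// Same shape as bind_values_to_descriptor_set(): walk the args, route each
// value to the bind helper for its runtime type.
uint32_t bind_values(const std::vector<Value>& args, uint32_t idx) {
  for (const Value& val : args) {
    if (std::holds_alternative<Tensor>(val)) {
      bind_tensor(std::get<Tensor>(val), idx++);
    } else if (std::holds_alternative<Staging>(val)) {
      bind_staging(std::get<Staging>(val), idx++);
    } else {
      throw std::runtime_error("Unsupported type");
    }
  }
  return idx;  // next free binding slot, as in the real helper
}

int main() {
  std::vector<Value> args{Tensor{}, Staging{}};
  return bind_values(args, 0) == 2 ? 0 : 1;
}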

backends/vulkan/runtime/graph/ops/Utils.h
Lines changed: 5 additions & 0 deletions

@@ -30,6 +30,11 @@ void bind_tensor_to_descriptor_set(
     api::DescriptorSet& descriptor_set,
     const uint32_t idx);

+void bind_staging_to_descriptor_set(
+    api::StorageBuffer& staging,
+    api::DescriptorSet& descriptor_set,
+    const uint32_t idx);
+
 uint32_t bind_values_to_descriptor_set(
     ComputeGraph* graph,
     const std::vector<ValueRef>& args,

backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Lines changed: 180 additions & 30 deletions

@@ -8,6 +8,7 @@

 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

+#include <ATen/native/vulkan/impl/Common.h>
 #include <ATen/native/vulkan/impl/Packing.h>

 namespace at {
@@ -72,7 +73,7 @@ void encode_copy_to_vtensor(
     api::Context* context,
     api::StorageBuffer& staging,
     vTensor& tensor) {
-  api::ShaderInfo shader = packing::get_nchw_to_image_shader(tensor);
+  api::ShaderInfo shader = get_nchw_to_image_shader(tensor);
   api::PipelineBarrier pipeline_barrier{};
   packing::record_nchw_to_image_op(
       context,
@@ -83,41 +84,190 @@
       VK_NULL_HANDLE);
 }

-void encode_copy_from_vtensor(
-    api::Context* context,
-    vTensor& tensor,
-    api::StorageBuffer& staging) {
-  api::ShaderInfo shader = packing::get_image_to_nchw_shader(tensor);
-  api::PipelineBarrier pipeline_barrier{};
-  packing::record_image_to_nchw_op(
-      context,
+struct StagingParams final {
+  api::utils::ivec3 extents;
+  int32_t plane_size;
+  api::utils::ivec2 channel_info;
+};
+
+StagingParams create_staging_params(const vTensor& t) {
+  int32_t height = api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(t));
+  int32_t width = api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(t));
+  int32_t channels =
+      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(t));
+
+  int32_t plane_size = height * width;
+  int32_t c_depth = api::utils::div_up(channels, 4);
+
+  return {
+      api::utils::make_ivec3(t.extents()),
+      plane_size,
+      {c_depth, channels},
+  };
+}
+
+void add_staging_to_tensor_node(
+    ComputeGraph& graph,
+    const ValueRef in_staging,
+    const ValueRef out_tensor) {
+  vTensor& t_out = graph.get_val(out_tensor).toTensor();
+  VK_CHECK_COND(graph.get_val(in_staging).isStaging());
+
+  api::ShaderInfo shader = get_nchw_to_image_shader(t_out);
+
+  api::utils::uvec3 global_size = t_out.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(
+      graph.context(), create_staging_params(t_out));
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
       shader,
-      tensor,
-      staging.buffer(),
-      pipeline_barrier,
-      VK_NULL_HANDLE);
+      global_size,
+      local_size,
+      {out_tensor},
+      {in_staging},
+      std::move(params)));
 }

-StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {}
+void add_tensor_to_staging_node(
+    ComputeGraph& graph,
+    const ValueRef in_tensor,
+    const ValueRef out_staging) {
+  vTensor& t_in = graph.get_val(in_tensor).toTensor();
+  VK_CHECK_COND(graph.get_val(out_staging).isStaging());

-void StagingNode::encode(ComputeGraph* graph) {
-  Value& in_val = graph->get_val(inputs_[0]);
-  Value& out_val = graph->get_val(outputs_[0]);
+  api::ShaderInfo shader = get_image_to_nchw_shader(t_in);
+
+  api::utils::uvec3 global_size = t_in.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  StagingParams sp = create_staging_params(t_in);
+  api::UniformParamsBuffer params(graph.context(), sp);
+
+  // TODO(T181194784): These are workgroup sizes for special cases. Refactor the
+  // calculation of workgroup sizes to a standalone function. We should use
+  // scalar type to get the shader name, and use the shader name to get the
+  // workgroup size.
+  if (t_in.dtype() == api::ScalarType::QUInt8 ||
+      t_in.dtype() == api::ScalarType::QInt8 || t_in.dtype() == api::kBool) {
+    if (sp.plane_size % 4 == 0) {
+      global_size.data[0u] = sp.plane_size / 4;
+      global_size.data[1u] = 1;
+      local_size.data[0u] *= local_size.data[1u];
+      local_size.data[1u] = 1;
+    } else {
+      uint32_t numel = t_in.numel();
+      global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u};
+      local_size = {64u, 1u, 1u};
+    }
+  }
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      shader,
+      global_size,
+      local_size,
+      {in_tensor},
+      {out_staging},
+      std::move(params)));
+}
+
+api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
+  if (v_dst.is_quantized()) {
+    switch (v_dst.storage_type()) {
+      case api::StorageType::TEXTURE_3D:
+        switch (v_dst.dtype()) {
+          case api::ScalarType::QUInt8:
+            return VK_KERNEL(nchw_to_image_uint8);
+          case api::ScalarType::QInt8:
+            return VK_KERNEL(nchw_to_image_int8);
+          case api::ScalarType::QInt32:
+            return VK_KERNEL(nchw_to_image_int32);
+          default:
+            VK_THROW(
+                "Vulkan quantization currently not supported for dtype ",
+                v_dst.dtype());
+        }
+      case api::StorageType::TEXTURE_2D:
+        switch (v_dst.dtype()) {
+          case api::ScalarType::QUInt8:
+            return VK_KERNEL(nchw_to_image2d_uint8);
+          case api::ScalarType::QInt8:
+            return VK_KERNEL(nchw_to_image2d_int8);
+          case api::ScalarType::QInt32:
+            return VK_KERNEL(nchw_to_image2d_int32);
+          default:
+            VK_THROW(
+                "Vulkan quantization currently not supported for dtype ",
+                v_dst.dtype());
        }
+      default:
+        VK_THROW("No kernel available!");
+      case api::StorageType::BUFFER:
+      case api::StorageType::UNKNOWN:
+        VK_THROW("Requested storage type must be a texture type.");
+    }
+  }
+
+  if (v_dst.dtype() == api::kFloat) {
+    switch (v_dst.storage_type()) {
+      case api::StorageType::TEXTURE_3D:
+        return VK_KERNEL(nchw_to_image);
+      case api::StorageType::TEXTURE_2D:
+        return VK_KERNEL(nchw_to_image2d);
+      default:
+        VK_THROW("No kernel available!");
+    }
+  } else if (v_dst.dtype() == api::kBool) {
+    switch (v_dst.storage_type()) {
+      case api::StorageType::TEXTURE_3D:
+        return VK_KERNEL(nchw_to_image_bool);
+      default:
+        VK_THROW("No kernel available!");
+    }
+  } else {
+    VK_THROW("Unsupported dtype!");
+  }
+}
+
+api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
+  if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
+    auto plane_size =
+        dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
+    switch (v_src.storage_type()) {
+      case api::StorageType::TEXTURE_3D:
+        switch (v_src.dtype()) {
+          case api::ScalarType::QUInt8:
+          case api::ScalarType::QInt8:
+          case api::kBool:
+            return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
+                                       : VK_KERNEL(image_to_nchw_uint);
+          case api::ScalarType::QInt32:
+            return VK_KERNEL(image_to_nchw_int32);
+          default:
+            VK_THROW(
+                "Vulkan quantization currently not supported for dtype ",
+                v_src.dtype());
+        }
+      default:
+        VK_THROW("No kernel available!");
+      case api::StorageType::BUFFER:
+      case api::StorageType::UNKNOWN:
+        VK_THROW("Requested storage type must be a texture type.");
+    }
+  }

-  if (in_val.isStaging() && out_val.isTensor()) {
-    api::StorageBuffer& from_staging = graph->get_val(inputs_[0]).toStaging();
-    vTensor& to_tensor = graph->get_val(outputs_[0]).toTensor();
-    encode_copy_to_vtensor(graph->context(), from_staging, to_tensor);
-  } else if (in_val.isTensor() && out_val.isStaging()) {
-    vTensor& from_tensor = graph->get_val(inputs_[0]).toTensor();
-    api::StorageBuffer& to_staging = graph->get_val(outputs_[0]).toStaging();
-    encode_copy_from_vtensor(graph->context(), from_tensor, to_staging);
+  if (v_src.dtype() == api::kFloat) {
+    switch (v_src.storage_type()) {
+      case api::StorageType::TEXTURE_3D:
+        return VK_KERNEL(image_to_nchw);
+      case api::StorageType::TEXTURE_2D:
+        return VK_KERNEL(image2d_to_nchw);
+      default:
+        VK_THROW("No kernel available!");
+    }
   } else {
-    VK_THROW(
-        "Unexpected input value type ",
-        in_val.type(),
-        " and output value type ",
-        out_val.type());
+    VK_THROW("Unsupported dtype!");
   }
 }
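The uniform parameters computed by create_staging_params() are simple texel-packing arithmetic: a channel plane holds H * W elements, and channels are packed four per RGBA texel, so the channel depth is ceil(C / 4). A standalone check of that arithmetic for a hypothetical tensor with C=6, H=5, W=4:

#include <cassert>
#include <cstdint>

// Ceiling division, mirroring api::utils::div_up.
int32_t div_up(int32_t n, int32_t d) {
  return (n + d - 1) / d;
}

int main() {
  const int32_t channels = 6;
  const int32_t height = 5;
  const int32_t width = 4;

  const int32_t plane_size = height * width;    // 20 elements per channel plane
  const int32_t c_depth = div_up(channels, 4);  // 6 channels -> 2 texel layers

  assert(plane_size == 20);
  assert(c_depth == 2);

  // The quantized/bool fast path in add_tensor_to_staging_node() is taken
  // only when plane_size % 4 == 0; here 20 % 4 == 0, so the shader would
  // handle 4 elements per invocation (global x = plane_size / 4 = 5).
  assert(plane_size % 4 == 0 && plane_size / 4 == 5);
  return 0;
}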

backends/vulkan/runtime/graph/ops/impl/Staging.h
Lines changed: 19 additions & 13 deletions

@@ -12,7 +12,7 @@

 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

-#include <string.h>
+#include <cstring>

 namespace at {
 namespace native {
@@ -76,20 +76,26 @@ void encode_copy_to_vtensor(
     api::Context* context,
     api::StorageBuffer& staging,
     vTensor& tensor);
-void encode_copy_from_vtensor(
-    api::Context* context,
-    vTensor& tensor,
-    api::StorageBuffer& staging);

-/*
- * OpNode that allows copying data into and out of a staging buffer.
- */
-class StagingNode : public virtual ExecuteNode {
- public:
-  explicit StagingNode(ValueRef from, ValueRef to);
+//
+// Functions to initialize ExecuteNode
+//
+
+void add_staging_to_tensor_node(
+    ComputeGraph& graph,
+    const ValueRef in_staging,
+    const ValueRef out_tensor);
+void add_tensor_to_staging_node(
+    ComputeGraph& graph,
+    const ValueRef in_tensor,
+    const ValueRef out_staging);
+
+//
+// Functions to get shaders
+//

-  void encode(ComputeGraph* graph) override;
-};
+api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst);
+api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src);

 } // namespace vulkan
 } // namespace native
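The two shader-lookup declarations map a (dtype, storage type) pair to a kernel. A toy, self-contained version of the nchw-to-image dispatch with stand-in enums and string names (the real code switches over api::ScalarType and api::StorageType and returns api::ShaderInfo via VK_KERNEL):

#include <stdexcept>
#include <string>

enum class Dtype { Float, Bool, QInt8 };
enum class Storage { Texture3D, Texture2D, Buffer };

std::string nchw_to_image_shader(Dtype dtype, Storage storage) {
  if (storage == Storage::Buffer) {
    throw std::runtime_error("Requested storage type must be a texture type.");
  }
  switch (dtype) {
    case Dtype::Float:
      return storage == Storage::Texture3D ? "nchw_to_image"
                                           : "nchw_to_image2d";
    case Dtype::QInt8:
      return storage == Storage::Texture3D ? "nchw_to_image_int8"
                                           : "nchw_to_image2d_int8";
    case Dtype::Bool:
      if (storage == Storage::Texture3D) {
        return "nchw_to_image_bool";  // bool only has a 3D-texture kernel
      }
      throw std::runtime_error("No kernel available!");
  }
  throw std::runtime_error("Unsupported dtype!");
}

int main() {
  return nchw_to_image_shader(Dtype::Float, Storage::Texture3D) ==
                 "nchw_to_image"
             ? 0
             : 1;
}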
