8
8
9
9
#include < executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
10
10
11
+ #include < ATen/native/vulkan/impl/Common.h>
11
12
#include < ATen/native/vulkan/impl/Packing.h>
12
13
13
14
namespace at {
@@ -72,7 +73,7 @@ void encode_copy_to_vtensor(
72
73
api::Context* context,
73
74
api::StorageBuffer& staging,
74
75
vTensor& tensor) {
75
- api::ShaderInfo shader = packing:: get_nchw_to_image_shader (tensor);
76
+ api::ShaderInfo shader = get_nchw_to_image_shader (tensor);
76
77
api::PipelineBarrier pipeline_barrier{};
77
78
packing::record_nchw_to_image_op (
78
79
context,
@@ -83,41 +84,190 @@ void encode_copy_to_vtensor(
83
84
VK_NULL_HANDLE);
84
85
}
85
86
86
- void encode_copy_from_vtensor (
87
- api::Context* context,
88
- vTensor& tensor,
89
- api::StorageBuffer& staging) {
90
- api::ShaderInfo shader = packing::get_image_to_nchw_shader (tensor);
91
- api::PipelineBarrier pipeline_barrier{};
92
- packing::record_image_to_nchw_op (
93
- context,
87
+ struct StagingParams final {
88
+ api::utils::ivec3 extents;
89
+ int32_t plane_size;
90
+ api::utils::ivec2 channel_info;
91
+ };
92
+
93
+ StagingParams create_staging_params (const vTensor& t) {
94
+ int32_t height = api::utils::safe_downcast<int32_t >(dim_at<Dim4D::Height>(t));
95
+ int32_t width = api::utils::safe_downcast<int32_t >(dim_at<Dim4D::Width>(t));
96
+ int32_t channels =
97
+ api::utils::safe_downcast<int32_t >(dim_at<Dim4D::Channel>(t));
98
+
99
+ int32_t plane_size = height * width;
100
+ int32_t c_depth = api::utils::div_up (channels, 4 );
101
+
102
+ return {
103
+ api::utils::make_ivec3 (t.extents ()),
104
+ plane_size,
105
+ {c_depth, channels},
106
+ };
107
+ }
108
+
109
+ void add_staging_to_tensor_node (
110
+ ComputeGraph& graph,
111
+ const ValueRef in_staging,
112
+ const ValueRef out_tensor) {
113
+ vTensor& t_out = graph.get_val (out_tensor).toTensor ();
114
+ VK_CHECK_COND (graph.get_val (in_staging).isStaging ());
115
+
116
+ api::ShaderInfo shader = get_nchw_to_image_shader (t_out);
117
+
118
+ api::utils::uvec3 global_size = t_out.extents ();
119
+ api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
120
+
121
+ api::UniformParamsBuffer params (
122
+ graph.context (), create_staging_params (t_out));
123
+
124
+ graph.execute_nodes ().emplace_back (new ExecuteNode (
94
125
shader,
95
- tensor,
96
- staging.buffer (),
97
- pipeline_barrier,
98
- VK_NULL_HANDLE);
126
+ global_size,
127
+ local_size,
128
+ {out_tensor},
129
+ {in_staging},
130
+ std::move (params)));
99
131
}
100
132
101
- StagingNode::StagingNode (ValueRef from, ValueRef to) : ExecuteNode(from, to) {}
133
+ void add_tensor_to_staging_node (
134
+ ComputeGraph& graph,
135
+ const ValueRef in_tensor,
136
+ const ValueRef out_staging) {
137
+ vTensor& t_in = graph.get_val (in_tensor).toTensor ();
138
+ VK_CHECK_COND (graph.get_val (out_staging).isStaging ());
102
139
103
- void StagingNode::encode (ComputeGraph* graph) {
104
- Value& in_val = graph->get_val (inputs_[0 ]);
105
- Value& out_val = graph->get_val (outputs_[0 ]);
140
+ api::ShaderInfo shader = get_image_to_nchw_shader (t_in);
141
+
142
+ api::utils::uvec3 global_size = t_in.extents ();
143
+ api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
144
+
145
+ StagingParams sp = create_staging_params (t_in);
146
+ api::UniformParamsBuffer params (graph.context (), sp);
147
+
148
+ // TODO(T181194784): These are workgroup sizes for special cases. Refactor the
149
+ // calculation of workgroup sizes to a standalone function. We should use
150
+ // scalar type to get the shader name, and use the shader name to get the
151
+ // workgroup size.
152
+ if (t_in.dtype () == api::ScalarType::QUInt8 ||
153
+ t_in.dtype () == api::ScalarType::QInt8 || t_in.dtype () == api::kBool ) {
154
+ if (sp.plane_size % 4 == 0 ) {
155
+ global_size.data [0u ] = sp.plane_size / 4 ;
156
+ global_size.data [1u ] = 1 ;
157
+ local_size.data [0u ] *= local_size.data [1u ];
158
+ local_size.data [1u ] = 1 ;
159
+ } else {
160
+ uint32_t numel = t_in.numel ();
161
+ global_size = {api::utils::div_up (numel, uint32_t (4 )), 1u , 1u };
162
+ local_size = {64u , 1u , 1u };
163
+ }
164
+ }
165
+
166
+ graph.execute_nodes ().emplace_back (new ExecuteNode (
167
+ shader,
168
+ global_size,
169
+ local_size,
170
+ {in_tensor},
171
+ {out_staging},
172
+ std::move (params)));
173
+ }
174
+
175
+ api::ShaderInfo get_nchw_to_image_shader (const vTensor& v_dst) {
176
+ if (v_dst.is_quantized ()) {
177
+ switch (v_dst.storage_type ()) {
178
+ case api::StorageType::TEXTURE_3D:
179
+ switch (v_dst.dtype ()) {
180
+ case api::ScalarType::QUInt8:
181
+ return VK_KERNEL (nchw_to_image_uint8);
182
+ case api::ScalarType::QInt8:
183
+ return VK_KERNEL (nchw_to_image_int8);
184
+ case api::ScalarType::QInt32:
185
+ return VK_KERNEL (nchw_to_image_int32);
186
+ default :
187
+ VK_THROW (
188
+ " Vulkan quantization currently not supported for dtype " ,
189
+ v_dst.dtype ());
190
+ }
191
+ case api::StorageType::TEXTURE_2D:
192
+ switch (v_dst.dtype ()) {
193
+ case api::ScalarType::QUInt8:
194
+ return VK_KERNEL (nchw_to_image2d_uint8);
195
+ case api::ScalarType::QInt8:
196
+ return VK_KERNEL (nchw_to_image2d_int8);
197
+ case api::ScalarType::QInt32:
198
+ return VK_KERNEL (nchw_to_image2d_int32);
199
+ default :
200
+ VK_THROW (
201
+ " Vulkan quantization currently not supported for dtype " ,
202
+ v_dst.dtype ());
203
+ }
204
+ default :
205
+ VK_THROW (" No kernel available!" );
206
+ case api::StorageType::BUFFER:
207
+ case api::StorageType::UNKNOWN:
208
+ VK_THROW (" Requested storage type must be a texture type." );
209
+ }
210
+ }
211
+
212
+ if (v_dst.dtype () == api::kFloat ) {
213
+ switch (v_dst.storage_type ()) {
214
+ case api::StorageType::TEXTURE_3D:
215
+ return VK_KERNEL (nchw_to_image);
216
+ case api::StorageType::TEXTURE_2D:
217
+ return VK_KERNEL (nchw_to_image2d);
218
+ default :
219
+ VK_THROW (" No kernel available!" );
220
+ }
221
+ } else if (v_dst.dtype () == api::kBool ) {
222
+ switch (v_dst.storage_type ()) {
223
+ case api::StorageType::TEXTURE_3D:
224
+ return VK_KERNEL (nchw_to_image_bool);
225
+ default :
226
+ VK_THROW (" No kernel available!" );
227
+ }
228
+ } else {
229
+ VK_THROW (" Unsupported dtype!" );
230
+ }
231
+ }
232
+
233
+ api::ShaderInfo get_image_to_nchw_shader (const vTensor& v_src) {
234
+ if (v_src.is_quantized () || v_src.dtype () == api::kBool ) {
235
+ auto plane_size =
236
+ dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
237
+ switch (v_src.storage_type ()) {
238
+ case api::StorageType::TEXTURE_3D:
239
+ switch (v_src.dtype ()) {
240
+ case api::ScalarType::QUInt8:
241
+ case api::ScalarType::QInt8:
242
+ case api::kBool :
243
+ return plane_size % 4 == 0 ? VK_KERNEL (image_to_nchw_quantized_mul4)
244
+ : VK_KERNEL (image_to_nchw_uint);
245
+ case api::ScalarType::QInt32:
246
+ return VK_KERNEL (image_to_nchw_int32);
247
+ default :
248
+ VK_THROW (
249
+ " Vulkan quantization currently not supported for dtype " ,
250
+ v_src.dtype ());
251
+ }
252
+ default :
253
+ VK_THROW (" No kernel available!" );
254
+ case api::StorageType::BUFFER:
255
+ case api::StorageType::UNKNOWN:
256
+ VK_THROW (" Requested storage type must be a texture type." );
257
+ }
258
+ }
106
259
107
- if (in_val.isStaging () && out_val.isTensor ()) {
108
- api::StorageBuffer& from_staging = graph->get_val (inputs_[0 ]).toStaging ();
109
- vTensor& to_tensor = graph->get_val (outputs_[0 ]).toTensor ();
110
- encode_copy_to_vtensor (graph->context (), from_staging, to_tensor);
111
- } else if (in_val.isTensor () && out_val.isStaging ()) {
112
- vTensor& from_tensor = graph->get_val (inputs_[0 ]).toTensor ();
113
- api::StorageBuffer& to_staging = graph->get_val (outputs_[0 ]).toStaging ();
114
- encode_copy_from_vtensor (graph->context (), from_tensor, to_staging);
260
+ if (v_src.dtype () == api::kFloat ) {
261
+ switch (v_src.storage_type ()) {
262
+ case api::StorageType::TEXTURE_3D:
263
+ return VK_KERNEL (image_to_nchw);
264
+ case api::StorageType::TEXTURE_2D:
265
+ return VK_KERNEL (image2d_to_nchw);
266
+ default :
267
+ VK_THROW (" No kernel available!" );
268
+ }
115
269
} else {
116
- VK_THROW (
117
- " Unexpected input value type " ,
118
- in_val.type (),
119
- " and output value type " ,
120
- out_val.type ());
270
+ VK_THROW (" Unsupported dtype!" );
121
271
}
122
272
}
123
273
0 commit comments