@@ -34,12 +34,12 @@ void resize_conv2d_node(
34
34
if (ndim == 4 ) {
35
35
new_out_sizes.at (ndim - 4 ) = self.sizes ().at (ndim - 4 );
36
36
}
37
- const auto weight_sizes = graph->get_val (extra_args[0 ]).toTensorRef ().sizes ;
37
+ const auto & weight_sizes = graph->get_val (extra_args[0 ]).toTensorRef ().sizes ;
38
38
new_out_sizes.at (ndim - 3 ) =
39
39
transposed ? weight_sizes.at (ndim - 3 ) : weight_sizes.at (ndim - 4 );
40
40
41
41
// Height, Width
42
- const auto new_out_sizes_hw = calc_out_sizes_hw (
42
+ const auto & new_out_sizes_hw = calc_out_sizes_hw (
43
43
*graph,
44
44
self.sizes (),
45
45
extra_args[0 ],
@@ -84,13 +84,24 @@ enum class Conv2dMethod : uint8_t {
84
84
};
85
85
86
86
api::ShaderInfo get_conv2d_shader (
87
+ ComputeGraph& graph,
87
88
const vTensor& t_out,
88
89
const bool prepack_weights,
89
- const Conv2dMethod method) {
90
+ const Conv2dMethod method,
91
+ const ValueRef weight) {
90
92
std::stringstream kernel_name;
91
93
switch (method) {
92
94
case Conv2dMethod::Depthwise:
93
95
kernel_name << " conv2d_dw" ;
96
+ if (!prepack_weights) {
97
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
98
+ if (weight_sizes.at (2 ) == 3 && weight_sizes.at (3 ) == 3 ) {
99
+ kernel_name << " _output_tile_3x3" ;
100
+ }
101
+ if (weight_sizes.at (2 ) == 5 && weight_sizes.at (3 ) == 5 ) {
102
+ kernel_name << " _output_tile_5x5" ;
103
+ }
104
+ }
94
105
break ;
95
106
case Conv2dMethod::SlidingWindow:
96
107
kernel_name << " conv2d" ;
@@ -153,7 +164,7 @@ ValueRef prepack_weights(
153
164
const ValueRef vref,
154
165
const Conv2dMethod method) {
155
166
const auto original_sizes = graph.get_val (vref).toTensorRef ().sizes ;
156
- const auto final_sizes = get_final_sizes (original_sizes, method);
167
+ const auto & final_sizes = get_final_sizes (original_sizes, method);
157
168
158
169
ValueRef v = graph.add_tensor (
159
170
final_sizes,
@@ -166,9 +177,9 @@ ValueRef prepack_weights(
166
177
api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
167
178
168
179
api::ShaderInfo shader =
169
- get_conv2d_shader (t, /* prepack_weights = */ true , method);
180
+ get_conv2d_shader (graph, t, /* prepack_weights = */ true , method, vref );
170
181
171
- const auto padded_sizes = get_padded_sizes (original_sizes, method);
182
+ const auto & padded_sizes = get_padded_sizes (original_sizes, method);
172
183
173
184
graph.prepack_nodes ().emplace_back (new PrepackNode (
174
185
graph,
@@ -205,13 +216,13 @@ Conv2dParams create_conv2d_params(
205
216
const ValueRef weight,
206
217
const KernelParams& p,
207
218
const bool transposed) {
208
- const auto overlay_region = api::utils::make_ivec2 ({
219
+ const auto & overlay_region = api::utils::make_ivec2 ({
209
220
p.kernel_size .data [0 ] +
210
221
(p.kernel_size .data [0 ] - 1 ) * (p.dilation .data [0 ] - 1 ),
211
222
p.kernel_size .data [1 ] +
212
223
(p.kernel_size .data [1 ] - 1 ) * (p.dilation .data [1 ] - 1 ),
213
224
});
214
- const auto weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
225
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
215
226
const int32_t in_group_size =
216
227
api::utils::safe_downcast<int32_t >(api::utils::align_up (
217
228
transposed ? weight_sizes.at (0 ) : weight_sizes.at (1 ), INT64_C (4 )));
@@ -239,7 +250,7 @@ Conv2dMethod get_conv2d_method(
239
250
const ValueRef weight,
240
251
const int64_t groups,
241
252
const bool transposed) {
242
- const auto weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
253
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
243
254
if (!transposed && weight_sizes.at (0 ) == groups && weight_sizes.at (1 ) == 1 ) {
244
255
return Conv2dMethod::Depthwise;
245
256
}
@@ -293,8 +304,8 @@ void add_conv2d_node(
293
304
294
305
check_conv2d_params (kernel_params, transposed_val);
295
306
296
- api::ShaderInfo shader =
297
- get_conv2d_shader ( t_out, /* prepack_weights = */ false , method);
307
+ api::ShaderInfo shader = get_conv2d_shader (
308
+ graph, t_out, /* prepack_weights = */ false , method, weight );
298
309
299
310
graph.execute_nodes ().emplace_back (new ExecuteNode (
300
311
graph,
0 commit comments