
Commit f00afe7

jorgep31415 authored and facebook-github-bot committed
aten.convolution (Depthwise Output-Tile) (#2885)
Summary:
Pull Request resolved: #2885

We port an optimization from ATen-VK for specific weight sizes: [`conv2d_dw_output_tile.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d_dw_output_tile.glsl)

ghstack-source-id: 221887576
exported-using-ghexport
bypass-github-export-checks

Reviewed By: SS-JIA

Differential Revision: D55814588

fbshipit-source-id: 86a85d122abbcebfed41466bc0a4907a6ddc80f9
1 parent 02f565e commit f00afe7

5 files changed, +130 -15 lines
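For context on what this optimization targets: the new shader variants apply only to depthwise convolutions with 3x3 or 5x5 filters, i.e. aten.convolution calls where groups equals the channel count and the weight has shape (C, 1, 3, 3) or (C, 1, 5, 5). A minimal sketch of such a call (illustrative only; the tensor names and sizes below are not from this commit, and standard PyTorch semantics are assumed):

```python
import torch

# Hypothetical example: a depthwise convolution (groups == in_channels) with a
# 3x3 weight of shape (C, 1, 3, 3). On the Vulkan backend this weight shape is
# the case served by the new conv2d_dw_output_tile_3x3 variant; other kernel
# sizes (apart from 5x5) keep using the generic conv2d_dw shader.
C = 8
x = torch.randn(1, C, 32, 32)
weight = torch.randn(C, 1, 3, 3)  # one 3x3 filter per channel
bias = torch.randn(C)

y = torch.nn.functional.conv2d(x, weight, bias, stride=1, padding=1, groups=C)
print(y.shape)  # torch.Size([1, 8, 32, 32])
```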

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 3 additions & 3 deletions
@@ -70,9 +70,9 @@ void main() {
   int kx = 0;
   for (int y = start.y; y < end.y; y += params.dilation.y) {
     for (int x = start.x; x < end.x; x += params.dilation.x) {
-      // The weight kernel was rearranged so that every NxN filter is flattened
-      // to fits in one row. Each filter was then stacked on top of each other
-      // vertically.
+      // The weight kernel was rearranged such that every NxN filter is
+      // flattened to fit in one row. Each filter was then stacked on top of
+      // each other vertically.
       const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
       sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
       ++kx;
backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
+layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
+
+layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
+  uvec4 data;
+}
+out_extents;
+
+layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
+  uvec4 data;
+}
+in_extents;
+
+layout(set = 0, binding = 6) uniform PRECISION restrict Params {
+  ivec2 kernel_size;
+  ivec2 stride;
+  ivec2 padding;
+  ivec2 dilation;
+}
+params;
+
+// If fields are separated, SwiftShader cannot identify in_group_size.
+layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
+  ivec2 overlay_region;
+  int in_group_size;
+}
+extra_params;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+/*
+ * Computes a depthwise convolution. Each shader invocation calculates the
+ * output at a single output location.
+ */
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
+    return;
+  }
+
+  // Compute the index of the top-left element of the overlay region. Negative
+  // indices indicate that the top-left element is in a region added by padding.
+  const ivec2 ipos = pos.xy * params.stride - params.padding;
+
+  // Compute the start and end of the input indices to load. Padding is assumed
+  // to be constant 0 padding, so any reads from the padding region is skipped.
+  const ivec2 start = ipos;
+  const ivec2 end = ipos + extra_params.overlay_region.xy;
+
+  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
+  int kx = 0;
+  for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) {
+    for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) {
+      // The weight kernel was rearranged such that every NxN filter is
+      // flattened to fit in one row. Each filter was then stacked on top of
+      // each other vertically.
+      const vec4 in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
+      sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
+      kx++;
+    }
+  }
+
+  imageStore(image_out, pos, sum);
+}
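To make the shader's comments easier to follow, here is a simplified single-channel NumPy reference (an illustrative sketch, not code from this commit) of what one shader invocation computes at a single output position. The real shader works on vec4 texels, four channels per fetch, and reads the filter row that was flattened and stacked during weight prepacking; this sketch keeps only the loop structure and indexing. The function name dw_output_at and its parameters are hypothetical.

```python
import numpy as np

def dw_output_at(inp, flat_kernel, bias, out_x, out_y,
                 stride=1, padding=1, dilation=1, tile_size=3):
    """Single-channel sketch of one output-tile shader invocation.

    inp: (H, W) input plane.
    flat_kernel: (tile_size * tile_size,) row, i.e. the NxN filter
    "flattened to fit in one row" as described in the shader comment.
    """
    h, w = inp.shape
    # Top-left of the overlay region; negative values fall in the padding.
    ipos_y = out_y * stride - padding
    ipos_x = out_x * stride - padding
    acc = bias
    kx = 0
    for i in range(tile_size):
        y = ipos_y + i * dilation
        for j in range(tile_size):
            x = ipos_x + j * dilation
            # Constant-zero padding: out-of-range reads contribute nothing.
            if 0 <= y < h and 0 <= x < w:
                acc += inp[y, x] * flat_kernel[kx]
            kx += 1
    return acc

# Usage sketch: one output element of a 3x3 depthwise conv with padding 1.
inp = np.arange(25, dtype=np.float32).reshape(5, 5)
k = np.ones(9, dtype=np.float32)
print(dw_output_at(inp, k, bias=0.0, out_x=2, out_y=2))  # sum of the 3x3 patch around (2, 2)
```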
backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+conv2d_dw_output_tile:
+  parameter_names_with_default_values:
+    NDIM: 3
+    DTYPE: float
+    TILE_SIZE: 3
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+        SUFFIX: half
+      - VALUE: float
+        SUFFIX: float
+  shader_variants:
+    - NAME: conv2d_dw_output_tile_3x3
+    - NAME: conv2d_dw_output_tile_5x5
+      TILE_SIZE: 5

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 22 additions & 11 deletions
@@ -34,12 +34,12 @@ void resize_conv2d_node(
   if (ndim == 4) {
     new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
   }
-  const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
+  const auto& weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
   new_out_sizes.at(ndim - 3) =
       transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);

   // Height, Width
-  const auto new_out_sizes_hw = calc_out_sizes_hw(
+  const auto& new_out_sizes_hw = calc_out_sizes_hw(
       *graph,
       self.sizes(),
       extra_args[0],
@@ -84,13 +84,24 @@ enum class Conv2dMethod : uint8_t {
 };

 api::ShaderInfo get_conv2d_shader(
+    ComputeGraph& graph,
     const vTensor& t_out,
     const bool prepack_weights,
-    const Conv2dMethod method) {
+    const Conv2dMethod method,
+    const ValueRef weight) {
   std::stringstream kernel_name;
   switch (method) {
     case Conv2dMethod::Depthwise:
       kernel_name << "conv2d_dw";
+      if (!prepack_weights) {
+        const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
+        if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
+          kernel_name << "_output_tile_3x3";
+        }
+        if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
+          kernel_name << "_output_tile_5x5";
+        }
+      }
       break;
     case Conv2dMethod::SlidingWindow:
       kernel_name << "conv2d";
@@ -153,7 +164,7 @@ ValueRef prepack_weights(
     const ValueRef vref,
     const Conv2dMethod method) {
   const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
-  const auto final_sizes = get_final_sizes(original_sizes, method);
+  const auto& final_sizes = get_final_sizes(original_sizes, method);

   ValueRef v = graph.add_tensor(
       final_sizes,
@@ -166,9 +177,9 @@ ValueRef prepack_weights(
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

   api::ShaderInfo shader =
-      get_conv2d_shader(t, /*prepack_weights = */ true, method);
+      get_conv2d_shader(graph, t, /*prepack_weights = */ true, method, vref);

-  const auto padded_sizes = get_padded_sizes(original_sizes, method);
+  const auto& padded_sizes = get_padded_sizes(original_sizes, method);

   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
@@ -205,13 +216,13 @@ Conv2dParams create_conv2d_params(
     const ValueRef weight,
     const KernelParams& p,
     const bool transposed) {
-  const auto overlay_region = api::utils::make_ivec2({
+  const auto& overlay_region = api::utils::make_ivec2({
       p.kernel_size.data[0] +
          (p.kernel_size.data[0] - 1) * (p.dilation.data[0] - 1),
       p.kernel_size.data[1] +
          (p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1),
   });
-  const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
+  const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
   const int32_t in_group_size =
       api::utils::safe_downcast<int32_t>(api::utils::align_up(
           transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4)));
@@ -239,7 +250,7 @@ Conv2dMethod get_conv2d_method(
     const ValueRef weight,
     const int64_t groups,
     const bool transposed) {
-  const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
+  const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
   if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
     return Conv2dMethod::Depthwise;
   }
@@ -293,8 +304,8 @@ void add_conv2d_node(

   check_conv2d_params(kernel_params, transposed_val);

-  api::ShaderInfo shader =
-      get_conv2d_shader(t_out, /*prepack_weights = */ false, method);
+  api::ShaderInfo shader = get_conv2d_shader(
+      graph, t_out, /*prepack_weights = */ false, method, weight);

   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
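For readability, a small Python mirror (illustrative only, not code from this commit; the function name depthwise_kernel_name is hypothetical) of the kernel-name selection that the Depthwise branch of get_conv2d_shader now performs: the output-tile variants are chosen only at execution time, not for weight prepacking, and only for 3x3 or 5x5 filters; otherwise the generic conv2d_dw shader is kept.

```python
def depthwise_kernel_name(weight_sizes, prepack_weights):
    # Mirrors the Depthwise branch of get_conv2d_shader(): append an
    # output-tile suffix only when not prepacking and only for 3x3 / 5x5.
    name = "conv2d_dw"
    if not prepack_weights:
        if weight_sizes[2] == 3 and weight_sizes[3] == 3:
            name += "_output_tile_3x3"
        elif weight_sizes[2] == 5 and weight_sizes[3] == 5:
            name += "_output_tile_5x5"
    return name

print(depthwise_kernel_name([8, 1, 3, 3], prepack_weights=False))  # conv2d_dw_output_tile_3x3
print(depthwise_kernel_name([8, 1, 7, 7], prepack_weights=False))  # conv2d_dw
```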

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ void resize_max_pool2d_node(
   new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

   // Height, Width
-  const auto new_out_sizes_hw = calc_out_sizes_hw(
+  const auto& new_out_sizes_hw = calc_out_sizes_hw(
       *graph,
       self.sizes(),
       extra_args[0],
