Skip to content

Commit cd023a4

Browse files
[ET-VK] Minor tuning for conv2d pw op to improve performance. (#11185)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11113 by @trivedivivek
^ Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/92/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/92/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/91/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/92/orig
@diff-train-skip-merge
---------
Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent ea745ac commit cd023a4

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,14 @@ void main() {
4949
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
5050
const ivec3 gpos = ivec3(
5151
gl_GlobalInvocationID.x % out_limits_scaled.x,
52-
div_by_x % out_limits_scaled.y,
53-
div_by_x / out_limits_scaled.y);
52+
div_by_x,
53+
gl_GlobalInvocationID.y);
54+
55+
// If the top left position is out of bounds, then this invocation will have
56+
// no work to do.
57+
if (gpos.y >= out_limits_scaled.y || gpos.z >= out_limits.z) {
58+
return;
59+
}
5460

5561
// If the top left position is out of bounds, then this invocation will have
5662
// no work to do.

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -398,8 +398,10 @@ void add_conv2d_node(
398398
utils::uvec3 wg_size = create_conv2d_global_wg_size(
399399
graph, method, out, weight_data, stride_equals_dilation);
400400

401-
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
401+
if (method == Conv2dMethod::Depthwise) {
402402
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
403+
} else if (method == Conv2dMethod::Pointwise) {
404+
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
403405
}
404406

405407
vkapi::ParamsBindList param_buffers;

0 commit comments

Comments
 (0)