Skip to content

Commit 6fb42ef

Browse files
[ET-VK] Adding boolean parameters to add_copy_offset_node to specify index calculation function in copy op's shader. (#9437)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9343 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/64/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/64/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/64/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent 8b8bd23 commit 6fb42ef

File tree

7 files changed

+116
-30
lines changed

7 files changed

+116
-30
lines changed

backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,29 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
3535
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
3636
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
3737

38+
${layout_declare_spec_const(C, "int", "batch_index_function", "0")}
39+
3840
void main() {
3941
const ivec3 pos = ivec3(gl_GlobalInvocationID);
4042

4143
if (any(greaterThanEqual(pos, range))) {
4244
return;
4345
}
4446

45-
const ivec3 in_pos = pos + src_offset.xyz;
47+
ivec3 in_pos = pos + src_offset.xyz;
4648
ivec3 out_pos = pos + dst_offset.xyz;
47-
48-
// If source channel size is specified compose output z based on channel and batch index
4949
if (src_offset.w > 0) {
50-
const int channel_index = in_pos.z % src_offset.w;
51-
const int batch_index = in_pos.z / src_offset.w;
52-
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
50+
if (batch_index_function == 1) {
51+
// batch index is calculated using source channel size
52+
const int channel_index = pos.z % src_offset.w;
53+
const int batch_index = pos.z / src_offset.w;
54+
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
55+
} else if (batch_index_function == 2) {
56+
// batch index is calculated using destination channel size
57+
const int channel_index = pos.z % dst_offset.w;
58+
const int batch_index = pos.z / dst_offset.w;
59+
in_pos.z = channel_index + src_offset.z + batch_index * src_offset.w;
60+
}
5361
}
5462

5563
write_texel_lpos(

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,57 @@ void main() {
4444
return;
4545
}
4646

47-
// Starting offset to write at within a texel
48-
const int out_lane_offset = dst_offset[packed_dim] & 0x3;
49-
const bool has_lane_offset = out_lane_offset != 0;
50-
5147
// Position in input tensor
52-
const ivec3 in_pos = pos + src_offset.xyz;
48+
ivec3 in_pos = pos + src_offset.xyz;
49+
in_pos[packed_dim] = pos[packed_dim] + (src_offset[packed_dim] >> 2);
5350

5451
// Read input value mapping to this output texel
55-
const VEC4_T in_value = load_texel_lpos(t_in, in_pos, in_axis_map);
52+
VEC4_T in_value = load_texel_lpos(t_in, in_pos, in_axis_map);
53+
54+
// Starting offset to read from a texel
55+
const int src_lane_offset = src_offset[packed_dim] & 0x3;
56+
const bool has_src_lane_offset = src_lane_offset != 0;
57+
58+
// If input lane offset is non-zero, i.e. the packed texel is composed from multiple sources
59+
if (has_src_lane_offset) {
60+
// Boundary values will come from next input texel in the packed dim.
61+
ivec3 next_in_pos = in_pos;
62+
next_in_pos[packed_dim] = in_pos[packed_dim] + 1;
63+
VEC4_T next_value = load_texel_lpos(t_in, next_in_pos, in_axis_map);
64+
65+
// Keep input values from the end of current input texel based on src_lane_offset
66+
// offset 1 means the first lane of current input texel is not a part of the output texel
67+
// offset 2 means first 2 lanes are not and so on
68+
if (src_lane_offset == 1) {
69+
in_value.xyz = in_value.yzw;
70+
} else if (src_lane_offset == 2) {
71+
in_value.xy = in_value.zw;
72+
} else {
73+
in_value.x = in_value.w;
74+
}
75+
// Copy next texel's values towards the end of input texel, based on lane offset
76+
// offset 1 means the first lane from next texel is part of the input texel
77+
// offset 2 means first 2 lanes from next texel are part of the input texel and so on
78+
if (src_lane_offset == 1) {
79+
in_value.w = next_value.x;
80+
} else if (src_lane_offset == 2) {
81+
in_value.zw = next_value.xy;
82+
} else {
83+
in_value.yzw = next_value.xyz;
84+
}
85+
}
86+
87+
// Starting offset to write at within a texel
88+
const int out_lane_offset = dst_offset[packed_dim] & 0x3;
89+
const bool has_dst_lane_offset = out_lane_offset != 0;
5690

5791
ivec3 out_pos = pos + dst_offset.xyz;
5892
out_pos[packed_dim] = pos[packed_dim] + (dst_offset[packed_dim] >> 2);
5993

6094
VEC4_T out_value;
6195

6296
// If lane offset is non-zero, i.e. the packed texel is composed from multiple sources
63-
if (has_lane_offset) {
97+
if (has_dst_lane_offset) {
6498
// When position in packed dim is > 0
6599
if (pos[packed_dim] > 0) {
66100
// Boundary values will come from previous input texel in the packed dim.

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void add_cat_default_node(
8080
// concatenating channels
8181
src_offset[3] = is_concat_channel ? in_channel_size : 0;
8282
add_copy_offset_node(
83-
graph, input_ref, range, src_offset, dst_offset, out);
83+
graph, input_ref, range, src_offset, dst_offset, out, true, false);
8484
dst_offset[dim_xyz_index] +=
8585
is_concat_channel ? in_channel_size : range[dim_xyz_index];
8686
}

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ void add_copy_offset_node(
2525
const ivec3& range,
2626
const ivec4& src_offset,
2727
const ivec4& dst_offset,
28-
const ValueRef out) {
28+
const ValueRef out,
29+
bool calc_out_pos_using_src_chnl,
30+
bool calc_in_pos_using_dst_chnl) {
2931
vTensorPtr t_in = graph.get_tensor(in);
3032
vTensorPtr t_out = graph.get_tensor(out);
3133

@@ -49,7 +51,11 @@ void add_copy_offset_node(
4951
// Parameter buffers
5052
{},
5153
// Specialization Constants
52-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
54+
{graph.hashed_layout_of(out),
55+
graph.hashed_layout_of(in),
56+
(calc_out_pos_using_src_chnl ? 1
57+
: calc_in_pos_using_dst_chnl ? 2
58+
: 0)},
5359
nullptr,
5460
{},
5561
{
@@ -86,19 +92,37 @@ void add_copy_packed_dim_offset_node(
8692
ivec4 final_range = {
8793
range[0], range[1], range[2], dim_at(t_in->sizes(), kBatch4D)};
8894
ivec3 global_wg_size = t_out->logical_limits();
95+
// The starting offset in a texel where this tensor will start copying from
96+
const auto src_lane_offset = src_offset[packed_dim] & 0x3;
8997
// The starting offset in a texel where this tensor will start copying to
9098
const auto dst_lane_offset = dst_offset[packed_dim] & 0x3;
99+
100+
// The total packed texels this tensor will be copied from
101+
// The first texel of tensor data in packed dimension will be copied from
102+
// remaining lanes from current source. Hence (4 - src_lane_offset) is added
103+
// to tensor size in packed dimension
104+
const auto src_packed_size = utils::div_up_4(
105+
(4 - src_lane_offset) +
106+
dim_at(t_out->sizes(), normalize_to_dim_index(*t_out, packed_dim)));
107+
91108
// The total packed texels this tensor will be copied to
92-
// The first texel of tensor data in packed dimension will be copied to remain
93-
// lanes from previous write Hence (4 - dst_lane_offset) is added to tensor
94-
// size in packed dimension
109+
// The first texel of tensor data in packed dimension will be copied to
110+
// remaining lanes from previous write. Hence (4 - dst_lane_offset) is added to
111+
// tensor size in packed dimension
95112
const auto dst_packed_size = utils::div_up_4(
96113
(4 - dst_lane_offset) +
97114
dim_at(t_in->sizes(), normalize_to_dim_index(*t_in, packed_dim)));
98115

99-
// If the starting offset is not 0, and the total packed texels is greater
116+
// If the starting src offset is not 0, and the total packed texels is greater
117+
// than the source texel range
118+
const bool has_additional_src_work =
119+
src_lane_offset != 0 && src_packed_size > final_range[packed_dim];
120+
// If the starting dst offset is not 0, and the total packed texels is greater
100121
// than the source texel range
101-
if (dst_lane_offset != 0 && dst_packed_size > final_range[packed_dim]) {
122+
const bool has_additional_dst_work =
123+
dst_lane_offset != 0 && dst_packed_size > final_range[packed_dim];
124+
125+
if (has_additional_src_work || has_additional_dst_work) {
102126
global_wg_size[packed_dim]++; // Increase the global work group size in
103127
// packed dimension
104128
final_range[packed_dim]++; // Increase the range in packed dimension
@@ -256,7 +280,8 @@ void add_copy_offset_node(
256280
ivec4 src_offset = {src[0], src[1], src[2], 0};
257281
ivec4 dst_offset = {dst[0], dst[1], dst[2], 0};
258282

259-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out);
283+
add_copy_offset_node(
284+
graph, in, range, src_offset, dst_offset, out, false, false);
260285
}
261286

262287
void copy_offset(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Copy.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,28 @@ namespace vkcompute {
2222
// It is possible to have input and output to point to the same image
2323
// object. But when the source range and destination range overlap, the behavior
2424
// is undefined.
25+
//
26+
// boolean flags calc_out_pos_using_src_chnl and calc_in_pos_using_dst_chnl
27+
// can be used to specify an indexing function in the shader
28+
// If calc_out_pos_using_src_chnl is set to true, channel and batch index will be
29+
// calculated based on source channel size and will be used to determine
30+
// destination texel position.
31+
//
32+
// If calc_in_pos_using_dst_chnl is set to true, channel and batch index will be
33+
// calculated based on destination channel size and will be used to determine
34+
// source texel position.
35+
//
36+
// If both are true, calc_out_pos_using_src_chnl is picked. If both are false, no
37+
// index calculation happens.
2538
void add_copy_offset_node(
2639
ComputeGraph& graph,
2740
const ValueRef in,
2841
const utils::ivec3& range,
2942
const utils::ivec4& src_offset,
3043
const utils::ivec4& dst_offset,
31-
const ValueRef out);
44+
const ValueRef out,
45+
bool calc_out_pos_using_src_chnl,
46+
bool calc_in_pos_using_dst_chnl);
3247

3348
// add_copy_packed_dim_offset_node behaves similar to add_copy_node, except that
3449
// its used when copying packed dimension, if tensor is width or height packed.

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ void add_repeat_node(
151151
utils::ivec4 src_offset{0, 0, 0, 0};
152152
utils::ivec4 dst_offset{0, 0, 0, 0};
153153

154-
add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
154+
add_copy_offset_node(
155+
graph, in, running_range, src_offset, dst_offset, out, false, false);
155156

156157
} else {
157158
add_repeat_channel_node(graph, in, channel_repeat, out, running_range);
@@ -166,7 +167,7 @@ void add_repeat_node(
166167
utils::ivec4 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0, 0};
167168

168169
add_copy_offset_node(
169-
graph, out, running_range, src_offset, dst_offset, out);
170+
graph, out, running_range, src_offset, dst_offset, out, true, false);
170171
}
171172

172173
running_range[0] = running_range[0] * width_repeat;
@@ -180,7 +181,7 @@ void add_repeat_node(
180181
utils::ivec4 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0, 0};
181182

182183
add_copy_offset_node(
183-
graph, out, running_range, src_offset, dst_offset, out);
184+
graph, out, running_range, src_offset, dst_offset, out, true, false);
184185
}
185186

186187
running_range[1] = running_range[1] * height_repeat;
@@ -194,7 +195,7 @@ void add_repeat_node(
194195
utils::ivec4 dst_offset = {0, 0, i * running_range[2], 0};
195196

196197
add_copy_offset_node(
197-
graph, out, running_range, src_offset, dst_offset, out);
198+
graph, out, running_range, src_offset, dst_offset, out, true, false);
198199
}
199200

200201
running_range[2] = running_range[2] * batch_repeat;

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void add_split_with_sizes_default_node(
5151
// output tensor's size matches with the split_size.
5252
vTensorPtr t_out = graph.get_tensor(out_ref);
5353
utils::ivec3 range = t_out->logical_limits();
54-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
54+
add_copy_offset_node(
55+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
5556

5657
src_offset[0] += range[0];
5758
}
@@ -62,7 +63,8 @@ void add_split_with_sizes_default_node(
6263
for (ValueRef out_ref : *out_list) {
6364
vTensorPtr t_out = graph.get_tensor(out_ref);
6465
utils::ivec3 range = t_out->logical_limits();
65-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
66+
add_copy_offset_node(
67+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
6668

6769
src_offset[1] += range[1];
6870
}
@@ -73,7 +75,8 @@ void add_split_with_sizes_default_node(
7375
for (ValueRef out_ref : *out_list) {
7476
vTensorPtr t_out = graph.get_tensor(out_ref);
7577
utils::ivec3 range = t_out->logical_limits();
76-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
78+
add_copy_offset_node(
79+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
7780

7881
src_offset[2] += range[2];
7982
}

0 commit comments

Comments
 (0)