Skip to content

[ET-VK] Adding boolean parameters to add_copy_offset_node to specify index calculation function in copy op's shader. #9343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,29 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "batch_index_function", "0")}

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, range))) {
return;
}

const ivec3 in_pos = pos + src_offset.xyz;
ivec3 in_pos = pos + src_offset.xyz;
ivec3 out_pos = pos + dst_offset.xyz;

// If source channel size is specified, compose the output z based on channel and batch index
if (src_offset.w > 0) {
const int channel_index = in_pos.z % src_offset.w;
const int batch_index = in_pos.z / src_offset.w;
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
if (batch_index_function == 1) {
// batch index is calculated using source channel size
const int channel_index = pos.z % src_offset.w;
const int batch_index = pos.z / src_offset.w;
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
} else if (batch_index_function == 2) {
// batch index is calculated using destination channel size
const int channel_index = pos.z % dst_offset.w;
const int batch_index = pos.z / dst_offset.w;
in_pos.z = channel_index + src_offset.z + batch_index * src_offset.w;
}
}

write_texel_lpos(
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void add_cat_default_node(
// concatenating channels
src_offset[3] = is_concat_channel ? in_channel_size : 0;
add_copy_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
graph, input_ref, range, src_offset, dst_offset, out, true, false);
dst_offset[dim_xyz_index] +=
is_concat_channel ? in_channel_size : range[dim_xyz_index];
}
Expand Down
13 changes: 10 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ void add_copy_offset_node(
const ivec3& range,
const ivec4& src_offset,
const ivec4& dst_offset,
const ValueRef out) {
const ValueRef out,
bool calc_out_pos_using_src_chnl,
bool calc_in_pos_using_dst_chnl) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

Expand All @@ -49,7 +51,11 @@ void add_copy_offset_node(
// Parameter buffers
{},
// Specialization Constants
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
{graph.hashed_layout_of(out),
graph.hashed_layout_of(in),
(calc_out_pos_using_src_chnl ? 1
: calc_in_pos_using_dst_chnl ? 2
: 0)},
nullptr,
{},
{
Expand Down Expand Up @@ -256,7 +262,8 @@ void add_copy_offset_node(
ivec4 src_offset = {src[0], src[1], src[2], 0};
ivec4 dst_offset = {dst[0], dst[1], dst[2], 0};

add_copy_offset_node(graph, in, range, src_offset, dst_offset, out);
add_copy_offset_node(
graph, in, range, src_offset, dst_offset, out, false, false);
}

void copy_offset(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand Down
17 changes: 16 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,28 @@ namespace vkcompute {
// It is possible to have input and output to point to the same image
// object. But when the source range and destination range overlap, the behavior
// is undefined.
//
// The boolean flags calc_out_pos_using_src_chnl and calc_in_pos_using_dst_chnl
// can be used to specify an indexing function in the shader.
// If calc_out_pos_using_src_chnl is set to true, the channel and batch index
// will be calculated based on the source channel size and will be used to
// determine the destination texel position.
//
// If calc_in_pos_using_dst_chnl is set to true, the channel and batch index
// will be calculated based on the destination channel size and will be used to
// determine the source texel position.
//
// If both are true, calc_out_pos_using_src_chnl is picked. If both are false,
// no index calculation happens.
void add_copy_offset_node(
ComputeGraph& graph,
const ValueRef in,
const utils::ivec3& range,
const utils::ivec4& src_offset,
const utils::ivec4& dst_offset,
const ValueRef out);
const ValueRef out,
bool calc_out_pos_using_src_chnl,
bool calc_in_pos_using_dst_chnl);

// add_copy_packed_dim_offset_node behaves similarly to add_copy_node, except
// that it is used when copying the packed dimension, if the tensor is width or
// height packed.
Expand Down
9 changes: 5 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ void add_repeat_node(
utils::ivec4 src_offset{0, 0, 0, 0};
utils::ivec4 dst_offset{0, 0, 0, 0};

add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
add_copy_offset_node(
graph, in, running_range, src_offset, dst_offset, out, false, false);

} else {
add_repeat_channel_node(graph, in, channel_repeat, out, running_range);
Expand All @@ -166,7 +167,7 @@ void add_repeat_node(
utils::ivec4 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0, 0};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
graph, out, running_range, src_offset, dst_offset, out, true, false);
}

running_range[0] = running_range[0] * width_repeat;
Expand All @@ -180,7 +181,7 @@ void add_repeat_node(
utils::ivec4 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0, 0};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
graph, out, running_range, src_offset, dst_offset, out, true, false);
}

running_range[1] = running_range[1] * height_repeat;
Expand All @@ -194,7 +195,7 @@ void add_repeat_node(
utils::ivec4 dst_offset = {0, 0, i * running_range[2], 0};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
graph, out, running_range, src_offset, dst_offset, out, true, false);
}

running_range[2] = running_range[2] * batch_repeat;
Expand Down
9 changes: 6 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void add_split_with_sizes_default_node(
// output tensor's size matches with the split_size.
vTensorPtr t_out = graph.get_tensor(out_ref);
utils::ivec3 range = t_out->logical_limits();
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
add_copy_offset_node(
graph, in, range, src_offset, dst_offset, out_ref, false, true);

src_offset[0] += range[0];
}
Expand All @@ -62,7 +63,8 @@ void add_split_with_sizes_default_node(
for (ValueRef out_ref : *out_list) {
vTensorPtr t_out = graph.get_tensor(out_ref);
utils::ivec3 range = t_out->logical_limits();
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
add_copy_offset_node(
graph, in, range, src_offset, dst_offset, out_ref, false, true);

src_offset[1] += range[1];
}
Expand All @@ -73,7 +75,8 @@ void add_split_with_sizes_default_node(
for (ValueRef out_ref : *out_list) {
vTensorPtr t_out = graph.get_tensor(out_ref);
utils::ivec3 range = t_out->logical_limits();
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
add_copy_offset_node(
graph, in, range, src_offset, dst_offset, out_ref, false, true);

src_offset[2] += range[2];
}
Expand Down
Loading