[ET-VK][Ops] aten.index_select #3744

1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -99,6 +99,7 @@ def __contains__(self, op):
]

INDEXING_OPS = [
    exir_ops.edge.aten.index_select.default,
    exir_ops.edge.aten.select_copy.int,
    exir_ops.edge.aten.slice_copy.Tensor,
]
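
For context (not part of the diff): aten.index_select gathers entries of the input along a single dimension, in the order given by an index tensor. A minimal eager-mode illustration of the semantics the partitioner now admits:

import torch

x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([2, 0], dtype=torch.int32)
out = torch.index_select(x, 1, idx)  # shape (2, 2, 4)
# out[:, 0, :] == x[:, 2, :] and out[:, 1, :] == x[:, 0, :]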
44 changes: 44 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_select.glsl
@@ -0,0 +1,44 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
${layout_declare_ubo(3, "ivec4", "sizes")}
${layout_declare_ubo(4, "int", "gpu_dim", "int", "stride")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(out_pos, sizes, packed_dim)) {
    return;
  }

  const int out_idx = out_pos[gpu_dim] / stride;
  const int within_stride = out_pos[gpu_dim] % stride;
  const int in_idx = texelFetch(t_idx, ivec3(out_idx, 0, 0), 0).x;

  ivec3 in_pos = out_pos;
  in_pos[gpu_dim] = in_idx * stride + within_stride;

  imageStore(t_out, out_pos, texelFetch(t_in, in_pos, 0));
}
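
This shader maps each output texel position back to an input position along a single GPU axis: gpu_dim selects the axis (x for width, y for height, z for batch) and stride accounts for a batch occupying ceil(C / 4) channel texels along z. An illustrative Python sketch of the same mapping (hypothetical helper, not part of the diff):

def map_out_to_in(out_pos, gpu_dim, stride, index):
    # Recover the logical selection index from the output position, look up
    # the source index, and rebuild the input position on the same axis.
    out_idx = out_pos[gpu_dim] // stride
    within_stride = out_pos[gpu_dim] % stride
    in_idx = index[out_idx]
    in_pos = list(out_pos)
    in_pos[gpu_dim] = in_idx * stride + within_stride
    return tuple(in_pos)

# e.g. batch selection with 6 channels (stride = ceil(6 / 4) = 2):
# map_out_to_in((5, 3, 7), gpu_dim=2, stride=2, index=[0, 2, 1, 1]) -> (5, 3, 3)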
11 changes: 11 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_select.yaml
@@ -0,0 +1,11 @@
index_select:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: index_select
55 changes: 55 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl
@@ -0,0 +1,55 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "in_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(out_pos, out_sizes, packed_dim)) {
    return;
  }

  const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim);
  const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim);

  vec4 out_texel;
  for (int i = 0; i < 4; ++i) {
    const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes);
    int out_channel = out_idx.z;
    int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x;

    ivec4 in_idx = out_idx;
    in_idx.z = in_channel;

    ivec4 in_elem_pos = to_texture_elem_pos(in_idx, in_sizes, packed_dim);

    vec4 in_texel = texelFetch(t_in, in_elem_pos.xyz, 0);

    out_texel[i] = in_texel[in_elem_pos.w];
  }
  imageStore(t_out, out_pos, out_texel);
}
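
Because channels are packed four per texel, the four lanes of one output texel generally come from four different input texels, so this shader resolves each lane separately instead of copying whole texels. A simplified Python model of that per-lane gather (hypothetical names; batch folding and padding lanes are ignored):

def gather_output_texel(texel_c, y, x, index, fetch_scalar):
    # texel_c: which group of four output channels this texel holds.
    # index: the index_select indices (output channel -> input channel).
    # fetch_scalar(c, y, x): reads one scalar from the channels-packed input.
    texel = [0.0, 0.0, 0.0, 0.0]
    for lane in range(4):
        out_channel = 4 * texel_c + lane  # logical output channel of this lane
        in_channel = index[out_channel]   # the t_idx lookup in the shader
        texel[lane] = fetch_scalar(in_channel, y, x)
    return texel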
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml
@@ -0,0 +1,10 @@
index_select_channel:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: index_select_channel
137 changes: 137 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp
@@ -0,0 +1,137 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void check_index_select_args(
    const vTensor& in,
    const vTensor& idx,
    const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(idx, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void add_index_select_channel_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef idx,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_idx = graph.get_tensor(idx);
  vTensorPtr t_out = graph.get_tensor(out);

  check_index_select_args(*t_in, *t_idx, *t_out);

  std::string kernel_name = "index_select_channel";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE},
       {{in, idx}, api::MemoryAccessType::READ}},
      {t_out->sizes_ubo(), t_in->sizes_ubo()}));
}

struct IndexSelectParams final {
  int32_t gpu_dim;
  int32_t stride;
};

IndexSelectParams create_index_select_params(
    const int64_t dim_idx,
    const vTensor& in) {
  if (dim_idx == kWidth4D) {
    return {0, 1};
  } else if (dim_idx == kHeight4D) {
    return {1, 1};
  } else if (dim_idx == kBatch4D) {
    int64_t n_channels = dim_at(in.sizes(), kChannel4D);
    int64_t stride = api::utils::div_up_4(n_channels);
    return {2, static_cast<int32_t>(stride)};
  } else {
    VK_THROW("Unexpected dim_idx!");
  }
}

void add_index_select_node(
    ComputeGraph& graph,
    ValueRef in,
    const int64_t dim_idx,
    ValueRef idx,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_idx = graph.get_tensor(idx);
  vTensorPtr t_out = graph.get_tensor(out);

  check_index_select_args(*t_in, *t_idx, *t_out);

  IndexSelectParams params = create_index_select_params(dim_idx, *t_in);

  std::string kernel_name = "index_select";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE},
       {{in, idx}, api::MemoryAccessType::READ}},
      {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
}

int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
  vTensorPtr t_in = graph.get_tensor(in);
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  dim = normalize(dim, t_in->dim());
  return normalize_to_dim_index(*t_in, dim);
}

void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef in = prepack_if_tensor_ref(graph, args[0]);
  ValueRef dim_ref = args[1];
  ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
  ValueRef out = args[3];

  const int64_t dim_idx = get_dim_idx(graph, in, dim_ref);
  if (dim_idx == kChannel4D) {
    add_index_select_channel_node(graph, in, idx, out);
  } else {
    add_index_select_node(graph, in, dim_idx, idx, out);
  }
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.index_select.default, index_select);
}

} // namespace vkcompute
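
The routing above sends channel selection to the dedicated index_select_channel shader, while width, height and batch selection share the generic shader parameterized by (gpu_dim, stride). A rough Python rendering of the parameter choice under the assumed channels-packed layout (batches stacked behind ceil(C / 4) channel texels along z); names are illustrative, not the actual C++ API:

import math

def index_select_params(dim, in_sizes_nchw):
    n, c, h, w = in_sizes_nchw
    if dim == "width":
        return {"gpu_dim": 0, "stride": 1}  # x axis, one slot per width element
    if dim == "height":
        return {"gpu_dim": 1, "stride": 1}  # y axis
    if dim == "batch":
        # each batch occupies ceil(C / 4) texels along the z axis
        return {"gpu_dim": 2, "stride": math.ceil(c / 4)}
    raise ValueError("channel selection is handled by the dedicated shader")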
28 changes: 25 additions & 3 deletions backends/vulkan/test/op_tests/cases.py
@@ -395,13 +395,34 @@ def get_slice_inputs():
    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

    test_suite.dtypes = ["at::kFloat"]
    test_suite.layouts = [
        "api::kChannelsPacked",
    ]
    test_suite.layouts = ["api::kChannelsPacked"]
    test_suite.data_gen = "make_seq_tensor"
    return test_suite


def get_index_select_inputs():
    Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"])
    Test.__new__.__defaults__ = (None, 0, None)

    test_cases = []

    for i in range(4):
        test_cases += [
            Test(self=[9, 9, 9, 9], dim=i, index=[0]),
            Test(self=[9, 9, 9, 9], dim=i, index=[2]),
            Test(self=[9, 9, 9, 9], dim=i, index=[0, 2]),
            Test(self=[9, 9, 9, 9], dim=i, index=[3, 1]),
            Test(self=[9, 9, 9, 9], dim=i, index=[5, 5]),
            Test(self=[9, 9, 9, 9], dim=i, index=[2, 3, 4, 5, 7, 10]),
        ]

    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

    test_suite.dtypes = ["at::kFloat"]
    test_suite.layouts = ["api::kChannelsPacked"]
    return test_suite


def get_unsqueeze_inputs():
    test_suite = VkTestSuite(
        [
@@ -795,6 +816,7 @@ def get_gelu_inputs():
    "aten.view_copy.default": get_view_inputs(),
    "aten.slice_copy.Tensor": get_slice_inputs(),
    "aten.slice.Tensor": get_slice_inputs(),
    "aten.index_select.default": get_index_select_inputs(),
    "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
    "aten.clone.default": get_clone_inputs(),
    "aten.repeat.default": get_repeat_inputs(),
19 changes: 17 additions & 2 deletions backends/vulkan/test/op_tests/utils/codegen_base.py
@@ -154,7 +154,12 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
        ret_str = f"{cpp_type} {arg.name} = "

        if cpp_type == AT_TENSOR:
            ret_str += f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
            if arg.name == "index":
                ret_str += f"make_index_tensor({init_list_str(data)});"
            else:
                ret_str += (
                    f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
                )
        elif cpp_type == OPT_AT_TENSOR:
            if str(data) == "None":
                ret_str += "std::nullopt;"
@@ -267,7 +272,7 @@ def generate_suite_cpp(self) -> str:

at::Tensor make_seq_tensor(
    std::vector<int64_t> sizes,
    at::ScalarType dtype = at::kFloat) {{
    at::ScalarType dtype = at::kFloat) {{
  int64_t n = 1;
  for (auto size: sizes) {{
    n *= size;
@@ -283,6 +288,16 @@ def generate_suite_cpp(self) -> str:
  return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
}}


at::Tensor make_index_tensor(std::vector<int64_t> indices) {{
  int64_t size = static_cast<int64_t>(indices.size());
  at::ScalarType dtype = at::kInt;

  // from_blob doesn't take ownership of data. Hence must create a copy as
  // "indices" will go out of scope. Read the int64_t data as kLong, then
  // convert to the int32 dtype expected by the index shaders.
  return at::from_blob(indices.data(), {{size}}, at::kLong).toType(dtype).detach().clone();
}}

{test_suites_cpp}
"""

36 changes: 36 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1255,3 +1255,39 @@ def forward(self, x):
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            custom_pass=[MeanToSumDiv()],
        )

    def test_vulkan_backend_index_select_channel(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices, dtype=torch.int32)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(96).reshape(2, 8, 2, 3).float(),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_index_select(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices, dtype=torch.int32)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(144).reshape(12, 1, 3, 4).float(),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )
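
As a reference point (not part of the diff), the expected output of the second test can be reproduced in eager mode with the same shape and indices:

import torch

x = torch.arange(144).reshape(12, 1, 3, 4).float()
index = torch.tensor([1, 3, 5, 7, 8, 9, 10, 11, 2, 3], dtype=torch.int32)
ref = torch.index_select(x, 0, index)
assert ref.shape == (10, 1, 3, 4)
assert torch.equal(ref[0], x[1]) and torch.equal(ref[-1], x[3])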