Implement aten.squeeze_copy.dims (#4223)
Summary:
Pull Request resolved: #4223

Implement aten.squeeze_copy.dims operator
This op is produced when compiling `torch.squeeze(x, dim=0)`.
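
For reference, `torch.squeeze` with an explicit dim removes that dimension only when its size is 1 and is a no-op otherwise. A minimal eager-mode sketch of the semantics the Vulkan op must reproduce:

import torch

x = torch.randn(1, 2, 2, 1)
print(torch.squeeze(x, 0).shape)  # torch.Size([2, 2, 1]): the size-1 dim is removed
print(torch.squeeze(x, 1).shape)  # torch.Size([1, 2, 2, 1]): no-op, dim 1 has size 2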

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: jorgep31415

Differential Revision: D59605342

fbshipit-source-id: 2acabe080360875937e4e48d427d6cc7fae802ff
Yujie Hui authored and facebook-github-bot committed Jul 12, 2024
1 parent 9221ab6 commit 7047162
Showing 6 changed files with 125 additions and 9 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -97,6 +97,7 @@ def __contains__(self, op):
 ]

 SHAPE_MANIPULATION_OPS = [
+    exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.unsqueeze_copy.default,
     exir_ops.edge.aten.view_copy.default,
     exir_ops.edge.aten.permute_copy.default,
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Clone.h
@@ -0,0 +1,19 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

void add_clone_node(ComputeGraph& graph, const ValueRef in, const ValueRef out);

} // namespace vkcompute
16 changes: 7 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -35,11 +35,6 @@ void check_args(
   // dim size as the argument. The code will work as long as the input tensor's
   // dim size is shorter than the permute dim array. In this case, the code
   // assumes a size of 1 at the higher dimensions.
-
-  int64_t out_dim = out.dim();
-  VK_CHECK_COND(
-      out_dim == permute_dims.size(),
-      "Output tensor dim size must match argument");
 }

 } // namespace
@@ -56,15 +51,18 @@ void add_permute_node(

   ivec4 out_dims{0, 1, 2, 3};

-  int64_t out_dim = t_out->dim();
-  std::vector<bool> seen(out_dim);
-  for (int i = 0; i < t_out->dim(); i++) {
+  // Special case for squeeze/unsqueeze: the input dim size can differ from
+  // the output dim size, so use t_in->dim() for squeeze and t_out->dim()
+  // for unsqueeze when building the permute parameters.
+  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  std::vector<bool> seen(out_ndim);
+  for (int i = 0; i < out_ndim; i++) {
     int64_t permute_dim = permute_dims[i];
     VK_CHECK_COND(
         !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
     seen[permute_dim] = true;

-    out_dims.data[(4u - out_dim) + i] = permute_dim + (4 - out_dim);
+    out_dims.data[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
   }

   std::string kernel_name = "permute";
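To illustrate the index math above: out_dims is always a 4-element permutation, and when the tensor rank is below 4 the permute dims are shifted into the trailing slots so the implicit leading (size-1) dimensions keep their identity mapping. A Python sketch of this packing, where pack_out_dims is a hypothetical name used only for illustration:

def pack_out_dims(permute_dims, in_ndim, out_tensor_ndim):
    # Mirrors the loop above: pad the permutation to 4 dims by shifting
    # the given dims into the trailing out_ndim slots of the 4-vector.
    out_ndim = max(in_ndim, out_tensor_ndim)
    out_dims = [0, 1, 2, 3]
    seen = [False] * out_ndim
    for i, p in enumerate(permute_dims):
        assert not seen[p], f"Argument dim {p} is repeated"
        seen[p] = True
        out_dims[(4 - out_ndim) + i] = p + (4 - out_ndim)
    return out_dims

# A rank-3 permute [2, 0, 1] becomes a rank-4 permutation whose untouched
# leading slot maps the implicit size-1 dimension to itself:
print(pack_out_dims([2, 0, 1], 3, 3))  # [0, 3, 1, 2]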
63 changes: 63 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
@@ -0,0 +1,63 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Clone.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_squeeze_copy_dims_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef dims_ref,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  IntListPtr dims = graph.get_int_list(dims_ref);
  std::vector<int64_t> squeeze_dims;
  // Filter out edge cases where we don't need to squeeze:
  // 1. The size of the squeeze dim is larger than 1.
  // 2. Squeezing the outermost dim.
  // For these cases, just pass the input through to the output via clone.
  for (int i = 0; i < dims->size(); ++i) {
    if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
      squeeze_dims.push_back(dims->at(i));
    }
  }
  if (squeeze_dims.size() == 0) {
    add_clone_node(graph, in, out);
  } else {
    std::vector<int64_t> permute_dims(t_in->dim());
    for (int i = 0; i < t_in->dim(); ++i) {
      permute_dims.at(i) = i;
    }
    for (auto& elem : squeeze_dims) {
      auto it = std::find(permute_dims.begin(), permute_dims.end(), elem);
      VK_CHECK_COND(
          it != permute_dims.end(), "Squeeze dim not found in permute_dims");
      std::rotate(permute_dims.begin(), it, it + 1);
    }

    add_permute_node(graph, in, permute_dims, out);
  }
}

void squeeze_copy_dims(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_squeeze_copy_dims_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.squeeze_copy.dims, squeeze_copy_dims);
}

} // namespace vkcompute
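
The approach above reduces squeeze to a permute: each size-1 squeeze dim is rotated to the front of the permutation (std::rotate), where it lines up with the implicit leading size-1 dims that the permute node assumes for the lower-rank output tensor. An eager-mode sketch of the same logic; squeeze_via_permute is a hypothetical helper, not part of this commit, and the final reshape stands in for the rank change the compiler bakes into the output tensor:

import torch

def squeeze_via_permute(x, dims):
    if isinstance(dims, int):
        dims = [dims]
    # Keep only dims that are actually size 1, skipping the outermost dim,
    # which the Vulkan op handles with a clone instead.
    squeeze_dims = [d for d in dims if d != 0 and x.size(d) == 1]
    if not squeeze_dims:
        return x.clone()
    perm = list(range(x.dim()))
    for d in squeeze_dims:
        i = perm.index(d)
        # Same effect as std::rotate(begin, it, it + 1): move the squeezed
        # (size-1) dim to the front of the permutation.
        perm = [perm[i]] + perm[:i] + perm[i + 1:]
    y = x.permute(perm)
    # All squeezed dims now lead with size 1; dropping them yields the
    # values the lower-rank output tensor holds.
    return y.reshape(y.shape[len(squeeze_dims):])

x = torch.randn(3, 1, 1, 5)
# Tuple dims for torch.squeeze need PyTorch >= 2.0.
assert torch.equal(squeeze_via_permute(x, [1, 2]), torch.squeeze(x, (1, 2)))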
19 changes: 19 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -1038,3 +1038,22 @@ def get_minimum_inputs():
        ]
    )
    return test_suite


@register_test_suite("aten.squeeze_copy.dims")
def get_squeeze_copy_dim_inputs():
    test_suite = VkTestSuite(
        [
            ([S, S, S, 1], 3),
            ([S, 1, S, S], 1),
            ([S, 1, 1, S], [1, 2]),
            ([1, S, S, S], 0),
            ([S, S, S, S], 3),
            ([S, S, S, S], 2),
            ([S, S, S, S], 1),
            ([M, M1, 1], 2),
            ([M, 1, M1], 1),
            ([1, M1, M1], 0),
        ]
    )
    return test_suite
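
For context, the expected eager-mode shapes for a few of the cases above, with concrete sizes standing in for the S/M/M1 placeholders defined elsewhere in this file:

import torch

x = torch.randn(4, 1, 1, 4)                       # stands in for [S, 1, 1, S]
assert torch.squeeze(x, (1, 2)).shape == (4, 4)   # both size-1 dims removed
y = torch.randn(4, 4, 4, 4)                       # [S, S, S, S]: no size-1 dims
assert torch.squeeze(y, 3).shape == y.shape       # squeeze is a no-op here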
16 changes: 16 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1157,6 +1157,22 @@ def forward(self, x):
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_squeeze(self):
        class SqueezeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.squeeze(x, 0)

        sample_inputs = (torch.randn(size=(1, 2, 2, 1), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SqueezeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_select(self):
        class SelectModule(torch.nn.Module):
            def __init__(self):