Skip to content

Commit a3e0b77

Browse files
author
morelos
committed
[ET-VK][Ops] aten.tan.default from scratch implementation
Pull Request resolved: #11046 Goal is to create the tan operator and its test case ghstack-source-id: 285689588 @exported-using-ghexport Differential Revision: [D75100188](https://our.internmc.facebook.com/intern/diff/D75100188/)
1 parent 71275e5 commit a3e0b77

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
// Shader preamble for the tan compute shader template. The ${...} and $if/$else
// constructs are template placeholders, not GLSL.
// NOTE(review): presumably expanded per DTYPE/STORAGE variant by the shader codegen - confirm.
#define PRECISION ${PRECISION}
12+
13+
// VEC4_T is the texel type used by the texture main() for the load and store;
// T is the scalar type the buffer main() casts its result to before storing.
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
#define T ${buffer_scalar_type(DTYPE)}
15+
16+
${define_active_storage_type(STORAGE)}
17+
18+
#include "indexing_utils.h"
19+
20+
${define_required_extensions(DTYPE)}
21+
22+
layout(std430) buffer;
23+
24+
// Bindings: 0 = output tensor t_out (declared "w"), 1 = input tensor t_in (declared "r").
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
26+
// Binding 2 is the bounds-check UBO: an int `numel` for the buffer variant,
// an ivec3 `out_limits` otherwise. The host side must bind the matching UBO.
$if STORAGE == "buffer":
27+
${layout_declare_ubo(2, "int", "numel")}
28+
$else:
29+
${layout_declare_ubo(2, "ivec3", "out_limits")}
30+
31+
// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
32+
33+
// NOTE(review): neither main() below appears to use anything from this header - confirm it is needed.
#include "activations.h"
34+
35+
#ifdef USING_BUFFER
36+
37+
// Buffer variant: each invocation handles the single element whose flat index
// equals its global x id. The value is widened to float, passed through tan(),
// and narrowed back to the buffer scalar type T on store.
void main() {
  const int idx = int(gl_GlobalInvocationID.x);

  // Only invocations that map to a real element do any work.
  if (idx < numel) {
    const float x = float(t_in[idx]);
    t_out[idx] = T(tan(x));
  }
}
46+
47+
#else
48+
49+
// Texture variant: each invocation handles the texel at its global invocation
// id and applies tan() component-wise to the fetched texel.
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // Proceed only when every coordinate lies strictly inside out_limits.
  if (all(lessThan(texel_pos, out_limits))) {
    const VEC4_T src = texelFetch(t_in, texel_pos, 0);
    imageStore(t_out, texel_pos, VEC4_T(tan(src)));
  }
}
59+
60+
#endif
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Variant configuration for the tan shader template.
# NOTE(review): presumably read by the shader codegen to fill in the template's
# DTYPE / STORAGE placeholders - confirm against the codegen script.
tan:
2+
# Default value for each template parameter.
parameter_names_with_default_values:
3+
DTYPE: float
4+
STORAGE: texture3d
5+
# Values listed per parameter: DTYPE in {half, float}, STORAGE in {texture3d, buffer}.
# NOTE(review): looks like one variant is generated per combination - confirm.
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
STORAGE:
10+
- VALUE: texture3d
11+
- VALUE: buffer
12+
# Base shader name; the host code looks kernels up as "tan" plus dtype and storage suffixes.
shader_variants:
13+
- NAME: tan
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
13+
14+
namespace vkcompute {
15+
16+
using namespace utils;
17+
18+
// Resize callback for the tan dispatch node: gives the output tensor the same
// sizes as the input tensor.
//
// args[0] holds the output reference and args[1] the input reference, in the
// order add_tan_node passes them to the dispatch node. extra_args is unused.
void resize_tan_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  const ValueRef out_ref = args[0].refs[0];
  const ValueRef in_ref = args[1].refs[0];

  vTensorPtr out_tensor = graph->get_tensor(out_ref);
  vTensorPtr in_tensor = graph->get_tensor(in_ref);

  out_tensor->virtual_resize(in_tensor->sizes());
}
28+
29+
// Appends a dispatch node to the graph that computes out = tan(in).
//
// The kernel name is "tan" followed by the dtype and storage-type suffixes
// derived from `out`, which selects the matching generated shader variant.
//
// The UBO bound after the two tensors must match what that shader variant
// declares at binding 2: the buffer variant declares `int numel` and
// bounds-checks a flat index against it, while the texture variant declares
// `ivec3 out_limits`. Binding the logical limits for a buffer-backed output
// would make the shader read the wrong quantity as its element count.
//
// @param graph  graph the dispatch node is added to
// @param in     reference to the input tensor (bound read-only)
// @param out    reference to the output tensor (bound write-only)
void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {
  std::string kernel_name = "tan";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));

  vkapi::ParamsBindList ubos({});
  if (graph.storage_type_of(out) == utils::kBuffer) {
    // Buffer shader variant expects the element count.
    ubos.append({graph.numel_ubo(out)});
  } else {
    // Texture shader variant expects the ivec3 extents.
    ubos.append({graph.logical_limits_ubo(out)});
  }

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      resize_tan_node));
}
55+
56+
// Operator entry point registered for aten.tan.default: unpacks the argument
// list (args[0] = input, args[1] = output) and records the tan dispatch node.
void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef out = args[1];
  add_tan_node(graph, in, out);
}
59+
60+
// Associates the operator name aten.tan.default with the tan() entry point above.
// NOTE(review): presumably executed at static-initialization time by the operator registry - confirm.
REGISTER_OPERATORS {
61+
VK_REGISTER_OP(aten.tan.default, tan);
62+
}
63+
64+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,22 @@ def get_unary_ops_inputs():
11641164
return test_suite
11651165

11661166

1167+
# Dedicated suite for tan, kept separate from the unary_ops suite for learning purposes.
@register_test_suite("aten.tan.default")
def get_tan_inputs():
    """Build the test suite for aten.tan.default.

    Covers 1-D through 4-D input shapes, run with both texture3d and buffer
    storage, using float inputs only.
    """
    input_shapes = [
        (M1,),
        (M1, M2),
        (S1, M1, M2),
        (S1, S2, S2, M2),
    ]
    suite = VkTestSuite(input_shapes)
    suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
    suite.dtypes = ["at::kFloat"]
    return suite
1181+
1182+
11671183
@register_test_suite("aten._native_batch_norm_legit_no_training.default")
11681184
def get_native_batch_norm_inputs():
11691185
Test = namedtuple(

0 commit comments

Comments
 (0)