Skip to content

Commit c5b5ad6

Browse files
committed
[ET-VK] Enable dynamic operator registration
Pull Request resolved: #2305 This change follows 1. in the footsteps of #2222 for static initialization and 2. the popular `TorchLibraryImpl` for wrapping with macros. https://www.internalfb.com/code/fbsource/[b6860acf0fd7a95224f2ed3f6fe48f699a9a45c0]/fbcode/caffe2/torch/library.h?lines=1004%2C1012-1026 Contributors can now write their operator and register them within the same file using `REGISTER_OPERATORS` + `VK_REGISTER_OP()`, as shown in `Arithmetic.h/cpp`. Typically in Linux/Android C++ environments, the symbols corresponding to `OperatorRegisterInit` static instances are discarded since they aren't used for anything other than static initialization. Hence, we need to `link_whole = True` for the `vulkan_graph_runtime` library. We update our Compute API tests to verify we can go through `OperatorRegistry` with proper static initialization. ghstack-source-id: 217884083 @exported-using-ghexport Differential Revision: [D54641117](https://our.internmc.facebook.com/intern/diff/D54641117/)
1 parent 4b3594b commit c5b5ad6

File tree

7 files changed

+59
-34
lines changed

7 files changed

+59
-34
lines changed

backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,17 @@ namespace native {
1515
namespace vulkan {
1616

1717
bool OperatorRegistry::has_op(const std::string& name) {
18-
return OperatorRegistry::kTable.count(name) > 0;
18+
return table_.count(name) > 0;
1919
}
2020

2121
OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn(
2222
const std::string& name) {
23-
return OperatorRegistry::kTable.find(name)->second;
23+
return table_.find(name)->second;
2424
}
2525

26-
// @lint-ignore-every CLANGTIDY modernize-avoid-bind
27-
// clang-format off
28-
#define OPERATOR_ENTRY(name, function) \
29-
{ #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) }
30-
// clang-format on
31-
32-
const OperatorRegistry::OpTable OperatorRegistry::kTable = {
33-
OPERATOR_ENTRY(aten.add.Tensor, add),
34-
OPERATOR_ENTRY(aten.sub.Tensor, sub),
35-
OPERATOR_ENTRY(aten.mul.Tensor, mul),
36-
OPERATOR_ENTRY(aten.div.Tensor, div),
37-
OPERATOR_ENTRY(aten.div.Tensor_mode, floor_div),
38-
OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow),
39-
};
26+
void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) {
27+
table_.insert(std::make_pair(name, fn));
28+
}
4029

4130
OperatorRegistry& operator_registry() {
4231
static OperatorRegistry registry;

backends/vulkan/runtime/graph/ops/OperatorRegistry.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
#define VK_GET_OP_FN(name) \
2121
::at::native::vulkan::operator_registry().get_op_fn(name)
2222

23+
#define VK_REGISTER_OP(name, function) \
24+
::at::native::vulkan::operator_registry().register_op( \
25+
#name, \
26+
std::bind(&function, std::placeholders::_1, std::placeholders::_2))
27+
28+
#define REGISTER_OPERATORS \
29+
static void register_ops(); \
30+
static const OperatorRegisterInit reg(&register_ops); \
31+
static void register_ops()
32+
2333
namespace at {
2434
namespace native {
2535
namespace vulkan {
@@ -35,7 +45,7 @@ class OperatorRegistry final {
3545
const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>;
3646
using OpTable = std::unordered_map<std::string, OpFunction>;
3747

38-
static const OpTable kTable;
48+
OpTable table_;
3949

4050
public:
4151
/*
@@ -47,6 +57,20 @@ class OperatorRegistry final {
4757
* Given an operator name, return the Vulkan delegate function
4858
*/
4959
OpFunction& get_op_fn(const std::string& name);
60+
61+
/*
62+
* Register a function to a given operator name
63+
*/
64+
void register_op(const std::string& name, OpFunction& fn);
65+
};
66+
67+
class OperatorRegisterInit final {
68+
using InitFn = void();
69+
70+
public:
71+
explicit OperatorRegisterInit(InitFn* init_fn) {
72+
init_fn();
73+
}
5074
};
5175

5276
// The Vulkan operator registry is global. It is retrieved using this function,

backends/vulkan/runtime/graph/ops/Utils.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ namespace at {
1616
namespace native {
1717
namespace vulkan {
1818

19-
#define DECLARE_OP_FN(function) \
20-
void function(ComputeGraph& graph, const std::vector<ValueRef>& args);
21-
2219
api::utils::ivec4 get_size_as_ivec4(const vTensor& t);
2320

2421
void bind_tensor_to_descriptor_set(

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1415

@@ -81,6 +82,15 @@ void add_arithmetic_node(
8182
std::move(params)));
8283
}
8384

85+
REGISTER_OPERATORS {
86+
VK_REGISTER_OP(aten.add.Tensor, add);
87+
VK_REGISTER_OP(aten.sub.Tensor, sub);
88+
VK_REGISTER_OP(aten.mul.Tensor, mul);
89+
VK_REGISTER_OP(aten.div.Tensor, div);
90+
VK_REGISTER_OP(aten.div.Tensor_mode, floor_div);
91+
VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
92+
}
93+
8494
} // namespace vulkan
8595
} // namespace native
8696
} // namespace at

backends/vulkan/runtime/graph/ops/impl/Arithmetic.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ namespace at {
1818
namespace native {
1919
namespace vulkan {
2020

21-
DECLARE_OP_FN(add);
22-
DECLARE_OP_FN(sub);
23-
DECLARE_OP_FN(mul);
24-
DECLARE_OP_FN(div);
25-
DECLARE_OP_FN(floor_div);
26-
DECLARE_OP_FN(pow);
27-
2821
void add_arithmetic_node(
2922
ComputeGraph& graph,
3023
const ValueRef in1,

backends/vulkan/targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def define_common_targets():
5656
"//caffe2:torch_vulkan_spv",
5757
],
5858
define_static_target = False,
59+
# Static initialization is used to register operators to the global operator registry,
60+
# therefore link_whole must be True to make sure unused symbols are not discarded.
61+
# @lint-ignore BUCKLINT: Avoid `link_whole=True`
62+
link_whole = True,
5963
)
6064

6165
runtime.cxx_library(
@@ -81,4 +85,6 @@ def define_common_targets():
8185
# VulkanBackend.cpp needs to compile with executor as whole
8286
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
8387
link_whole = True,
88+
# Define an soname that can be used for dynamic loading in Java, Python, etc.
89+
soname = "libvulkan_graph_runtime.$(ext)",
8490
)

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <ATen/native/vulkan/api/api.h>
1212

1313
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1416
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
1517

1618
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
@@ -585,8 +587,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
585587

586588
out.value = graph.add_tensor(size_big, api::kFloat);
587589

588-
add_arithmetic_node(
589-
graph, a.value, b.value, kDummyValueRef, out.value, VK_KERNEL(add));
590+
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
591+
addFn(graph, {a.value, b.value, kDummyValueRef, out.value});
590592

591593
out.staging = graph.set_output_tensor(out.value);
592594

@@ -636,8 +638,11 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
636638
ValueRef c = graph.add_tensor(size_big, api::kFloat);
637639
ValueRef e = graph.add_tensor(size_big, api::kFloat);
638640

639-
add_arithmetic_node(graph, a.value, w1, kDummyValueRef, c, VK_KERNEL(add));
640-
add_arithmetic_node(graph, c, w2, kDummyValueRef, e, VK_KERNEL(mul));
641+
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
642+
addFn(graph, {a.value, w1, kDummyValueRef, c});
643+
644+
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
645+
mulFn(graph, {c, w2, e});
641646

642647
IOValueRef out = {};
643648
out.value = e;
@@ -697,8 +702,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
697702
api::kFloat,
698703
/*shared_object_idx = */ 6);
699704

700-
add_arithmetic_node(
701-
graph, a.value, b.value, kDummyValueRef, c, VK_KERNEL(add));
705+
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
706+
addFn(graph, {a.value, b.value, kDummyValueRef, c});
702707

703708
IOValueRef d = graph.add_input_tensor(
704709
size_small,
@@ -716,7 +721,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
716721
api::kFloat,
717722
/*shared_object_idx = */ 4);
718723

719-
add_arithmetic_node(graph, c, d.value, kDummyValueRef, e, VK_KERNEL(mul));
724+
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
725+
mulFn(graph, {c, d.value, e});
720726

721727
IOValueRef out = {};
722728
out.value = e;

0 commit comments

Comments
 (0)