Skip to content

[ET-VK] Enable dynamic operator registration #2305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,17 @@ namespace native {
namespace vulkan {

bool OperatorRegistry::has_op(const std::string& name) {
return OperatorRegistry::kTable.count(name) > 0;
return table_.count(name) > 0;
}

OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn(
const std::string& name) {
return OperatorRegistry::kTable.find(name)->second;
return table_.find(name)->second;
}

// @lint-ignore-every CLANGTIDY modernize-avoid-bind
// clang-format off
#define OPERATOR_ENTRY(name, function) \
{ #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) }
// clang-format on

const OperatorRegistry::OpTable OperatorRegistry::kTable = {
OPERATOR_ENTRY(aten.add.Tensor, add),
OPERATOR_ENTRY(aten.sub.Tensor, sub),
OPERATOR_ENTRY(aten.mul.Tensor, mul),
OPERATOR_ENTRY(aten.div.Tensor, div),
OPERATOR_ENTRY(aten.div.Tensor_mode, floor_div),
OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow),
};
void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) {
table_.insert(std::make_pair(name, fn));
}

OperatorRegistry& operator_registry() {
static OperatorRegistry registry;
Expand Down
26 changes: 25 additions & 1 deletion backends/vulkan/runtime/graph/ops/OperatorRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
#define VK_GET_OP_FN(name) \
::at::native::vulkan::operator_registry().get_op_fn(name)

#define VK_REGISTER_OP(name, function) \
::at::native::vulkan::operator_registry().register_op( \
#name, \
std::bind(&function, std::placeholders::_1, std::placeholders::_2))

#define REGISTER_OPERATORS \
static void register_ops(); \
static const OperatorRegisterInit reg(&register_ops); \
static void register_ops()

namespace at {
namespace native {
namespace vulkan {
Expand All @@ -35,7 +45,7 @@ class OperatorRegistry final {
const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>;
using OpTable = std::unordered_map<std::string, OpFunction>;

static const OpTable kTable;
OpTable table_;

public:
/*
Expand All @@ -47,6 +57,20 @@ class OperatorRegistry final {
* Given an operator name, return the Vulkan delegate function
*/
OpFunction& get_op_fn(const std::string& name);

/*
* Register a function to a given operator name
*/
void register_op(const std::string& name, OpFunction& fn);
};

class OperatorRegisterInit final {
using InitFn = void();

public:
explicit OperatorRegisterInit(InitFn* init_fn) {
init_fn();
}
};

// The Vulkan operator registry is global. It is retrieved using this function,
Expand Down
3 changes: 0 additions & 3 deletions backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ namespace at {
namespace native {
namespace vulkan {

#define DECLARE_OP_FN(function) \
void function(ComputeGraph& graph, const std::vector<ValueRef>& args);

api::utils::ivec4 get_size_as_ivec4(const vTensor& t);

void bind_tensor_to_descriptor_set(
Expand Down
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

Expand Down Expand Up @@ -81,6 +82,15 @@ void add_arithmetic_node(
std::move(params)));
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.add.Tensor, add);
VK_REGISTER_OP(aten.sub.Tensor, sub);
VK_REGISTER_OP(aten.mul.Tensor, mul);
VK_REGISTER_OP(aten.div.Tensor, div);
VK_REGISTER_OP(aten.div.Tensor_mode, floor_div);
VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
}

} // namespace vulkan
} // namespace native
} // namespace at
7 changes: 0 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@ namespace at {
namespace native {
namespace vulkan {

DECLARE_OP_FN(add);
DECLARE_OP_FN(sub);
DECLARE_OP_FN(mul);
DECLARE_OP_FN(div);
DECLARE_OP_FN(floor_div);
DECLARE_OP_FN(pow);

void add_arithmetic_node(
ComputeGraph& graph,
const ValueRef in1,
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def define_common_targets():
"//caffe2:torch_vulkan_spv",
],
define_static_target = False,
# Static initialization is used to register operators to the global operator registry,
# therefore link_whole must be True to make sure unused symbols are not discarded.
# @lint-ignore BUCKLINT: Avoid `link_whole=True`
link_whole = True,
)

runtime.cxx_library(
Expand All @@ -81,4 +85,6 @@ def define_common_targets():
# VulkanBackend.cpp needs to compile with executor as whole
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
# Define an soname that can be used for dynamic loading in Java, Python, etc.
soname = "libvulkan_graph_runtime.$(ext)",
)
20 changes: 13 additions & 7 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
Expand Down Expand Up @@ -585,8 +587,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {

out.value = graph.add_tensor(size_big, api::kFloat);

add_arithmetic_node(
graph, a.value, b.value, kDummyValueRef, out.value, VK_KERNEL(add));
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
addFn(graph, {a.value, b.value, kDummyValueRef, out.value});

out.staging = graph.set_output_tensor(out.value);

Expand Down Expand Up @@ -636,8 +638,11 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
ValueRef c = graph.add_tensor(size_big, api::kFloat);
ValueRef e = graph.add_tensor(size_big, api::kFloat);

add_arithmetic_node(graph, a.value, w1, kDummyValueRef, c, VK_KERNEL(add));
add_arithmetic_node(graph, c, w2, kDummyValueRef, e, VK_KERNEL(mul));
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
addFn(graph, {a.value, w1, kDummyValueRef, c});

auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
mulFn(graph, {c, w2, e});

IOValueRef out = {};
out.value = e;
Expand Down Expand Up @@ -697,8 +702,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
api::kFloat,
/*shared_object_idx = */ 6);

add_arithmetic_node(
graph, a.value, b.value, kDummyValueRef, c, VK_KERNEL(add));
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
addFn(graph, {a.value, b.value, kDummyValueRef, c});

IOValueRef d = graph.add_input_tensor(
size_small,
Expand All @@ -716,7 +721,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
api::kFloat,
/*shared_object_idx = */ 4);

add_arithmetic_node(graph, c, d.value, kDummyValueRef, e, VK_KERNEL(mul));
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
mulFn(graph, {c, d.value, e});

IOValueRef out = {};
out.value = e;
Expand Down