Skip to content

Commit 9c0a23b

Browse files
committed
[ET-VK][EZ] Clean up OperatorRegistry
Align `OperatorRegistry` with the style of `ShaderRegistry` in #2222 This means - Improve comments and comment formatting. - Use snake case, even if it deviates from the original registry I was following. Snake case is more consistent with the Vulkan backend code. https://www.internalfb.com/code/fbsource/[a97f9ed1a715231bb61b05942273f1e8f8631503]/fbcode/executorch/runtime/kernel/operator_registry.h?lines=208%2C213 - Move `using` declarations and member variables to top of class definition. - Place static `OperatorRegistry` instance declaration in a global function `operator_registry()` instead of in member function `getInstance()`. - Use macros to wrap `OperatorRegistry` functions instead of global functions. - For simplicity, remove unneeded ctor and assignment operator deletion/hiding. Note users can now create their own non-static `OperatorRegistry` instance and we can consider hiding this again later. Differential Revision: [D54640160](https://our.internmc.facebook.com/intern/diff/D54640160/) [ghstack-poisoned]
1 parent 47b837b commit 9c0a23b

File tree

3 files changed

+40
-45
lines changed

3 files changed

+40
-45
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ class GraphBuilder {
173173
// Parse the operators
174174
for (OpCallPtr op_call : *(flatbuffer_->chain())) {
175175
std::string op_name = op_call->name()->str();
176-
ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());
176+
ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str());
177177

178178
const std::vector<int> arg_fb_ids(
179179
op_call->args()->cbegin(), op_call->args()->cend());
@@ -183,7 +183,7 @@ class GraphBuilder {
183183
args.push_back(get_fb_id_valueref(arg_fb_id));
184184
}
185185

186-
auto vkFn = getOpsFn(op_name);
186+
auto vkFn = VK_GET_OP_FN(op_name);
187187
vkFn(*compute_graph_, args);
188188
}
189189

backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,19 @@ namespace at {
1414
namespace native {
1515
namespace vulkan {
1616

17-
bool hasOpsFn(const std::string& name) {
18-
return OperatorRegistry::getInstance().hasOpsFn(name);
19-
}
20-
21-
OpFunction& getOpsFn(const std::string& name) {
22-
return OperatorRegistry::getInstance().getOpsFn(name);
23-
}
24-
25-
OperatorRegistry& OperatorRegistry::getInstance() {
26-
static OperatorRegistry instance;
27-
return instance;
28-
}
29-
30-
bool OperatorRegistry::hasOpsFn(const std::string& name) {
17+
bool OperatorRegistry::has_op(const std::string& name) {
3118
return OperatorRegistry::kTable.count(name) > 0;
3219
}
3320

34-
OpFunction& OperatorRegistry::getOpsFn(const std::string& name) {
21+
OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn(
22+
const std::string& name) {
3523
return OperatorRegistry::kTable.find(name)->second;
3624
}
3725

3826
// @lint-ignore-every CLANGTIDY modernize-avoid-bind
3927
// clang-format off
4028
#define OPERATOR_ENTRY(name, function) \
41-
{ #name, std::bind(&at::native::vulkan::function, std::placeholders::_1, std::placeholders::_2) }
29+
{ #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) }
4230
// clang-format on
4331

4432
const OperatorRegistry::OpTable OperatorRegistry::kTable = {
@@ -50,6 +38,11 @@ const OperatorRegistry::OpTable OperatorRegistry::kTable = {
5038
OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow),
5139
};
5240

41+
OperatorRegistry& operator_registry() {
42+
static OperatorRegistry registry;
43+
return registry;
44+
}
45+
5346
} // namespace vulkan
5447
} // namespace native
5548
} // namespace at

backends/vulkan/runtime/graph/ops/OperatorRegistry.h

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,44 @@
1515
#include <functional>
1616
#include <unordered_map>
1717

18+
#define VK_HAS_OP(name) ::at::native::vulkan::operator_registry().has_op(name)
19+
20+
#define VK_GET_OP_FN(name) \
21+
::at::native::vulkan::operator_registry().get_op_fn(name)
22+
1823
namespace at {
1924
namespace native {
2025
namespace vulkan {
2126

22-
using OpFunction =
23-
const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>;
24-
25-
bool hasOpsFn(const std::string& name);
26-
27-
OpFunction& getOpsFn(const std::string& name);
28-
29-
// The Vulkan operator registry is a simplified version of
30-
// fbcode/executorch/runtime/kernel/operator_registry.h
31-
// that uses the C++ Standard Library.
32-
class OperatorRegistry {
33-
public:
34-
static OperatorRegistry& getInstance();
35-
36-
bool hasOpsFn(const std::string& name);
37-
OpFunction& getOpsFn(const std::string& name);
38-
39-
OperatorRegistry(const OperatorRegistry&) = delete;
40-
OperatorRegistry(OperatorRegistry&&) = delete;
41-
OperatorRegistry& operator=(const OperatorRegistry&) = delete;
42-
OperatorRegistry& operator=(OperatorRegistry&&) = delete;
43-
44-
private:
45-
// TODO: Input string corresponds to target_name. We may need to pass kwargs.
27+
/*
28+
* The Vulkan operator registry maps ATen operator names to their Vulkan
29+
* delegate function implementation. It is a simplified version of
30+
* executorch/runtime/kernel/operator_registry.h that uses the C++ Standard
31+
* Library.
32+
*/
33+
class OperatorRegistry final {
34+
using OpFunction =
35+
const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>;
4636
using OpTable = std::unordered_map<std::string, OpFunction>;
47-
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
37+
4838
static const OpTable kTable;
4939

50-
OperatorRegistry() = default;
51-
~OperatorRegistry() = default;
40+
public:
41+
/*
42+
* Check if the registry has an operator registered under the given name
43+
*/
44+
bool has_op(const std::string& name);
45+
46+
/*
47+
* Given an operator name, return the Vulkan delegate function
48+
*/
49+
OpFunction& get_op_fn(const std::string& name);
5250
};
5351

52+
// The Vulkan operator registry is global. It is retrieved using this function,
53+
// where it is declared as a static local variable.
54+
OperatorRegistry& operator_registry();
55+
5456
} // namespace vulkan
5557
} // namespace native
5658
} // namespace at

0 commit comments

Comments
 (0)