forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRegistry.h
309 lines (267 loc) · 12 KB
/
Registry.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
#ifndef C10_UTIL_REGISTRY_H_
#define C10_UTIL_REGISTRY_H_
/**
* Simple registry implementation that uses static variables to
* register object creators during program initialization time.
*/
// NB: This Registry works poorly when you have other namespaces.
// Make all macro invocations from inside the at namespace.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
namespace c10 {
/**
 * Produces a printable representation of a registry key for diagnostics.
 *
 * This primary template is the fallback for key types without a dedicated
 * specialization: the key itself is ignored and a fixed placeholder string is
 * returned.
 */
template <typename KeyType>
inline std::string KeyStrRepr(const KeyType& /*key*/) {
  const std::string placeholder("[key type printing not supported]");
  return placeholder;
}

/**
 * Specialization for std::string keys: the printable form of the key is the
 * key itself.
 */
template <>
inline std::string KeyStrRepr<std::string>(const std::string& key) {
  return std::string(key);
}
/**
 * Priority attached to every registration in a Registry.
 *
 * Registry::Register compares the priority of a new registration against the
 * priority stored for the same key: a strictly greater priority replaces the
 * stored creator, an equal priority is reported as an error, and a smaller
 * priority leaves the stored creator in place. REGISTRY_DEFAULT is the value
 * used when a caller does not pass a priority.
 */
enum RegistryPriority {
// Lowest value; loses against both other priorities.
REGISTRY_FALLBACK = 1,
// Default argument of Registry::Register.
REGISTRY_DEFAULT = 2,
// Highest value; wins against both other priorities.
REGISTRY_PREFERRED = 3,
};
/**
* @brief A template class that allows one to register classes by keys.
*
* The keys are usually a std::string specifying the name, but can be any type
* that can be used as a std::unordered_map key (hashable and
* equality-comparable).
*
* You should most likely not use the Registry class explicitly, but use the
* helper macros below to declare specific registries as well as registering
* objects.
*/
template <class SrcType, class ObjectPtrType, class... Args>
class Registry {
public:
typedef std::function<ObjectPtrType(Args...)> Creator;
Registry(bool warning = true)
: registry_(), priority_(), terminate_(true), warning_(warning) {}
void Register(
const SrcType& key,
Creator creator,
const RegistryPriority priority = REGISTRY_DEFAULT) {
std::lock_guard<std::mutex> lock(register_mutex_);
// The if statement below is essentially the same as the following line:
// CHECK_EQ(registry_.count(key), 0) << "Key " << key
// << " registered twice.";
// However, CHECK_EQ depends on google logging, and since registration is
// carried out at static initialization time, we do not want to have an
// explicit dependency on glog's initialization function.
if (registry_.count(key) != 0) {
auto cur_priority = priority_[key];
if (priority > cur_priority) {
#ifdef DEBUG
std::string warn_msg =
"Overwriting already registered item for key " + KeyStrRepr(key);
fprintf(stderr, "%s\n", warn_msg.c_str());
#endif
registry_[key] = creator;
priority_[key] = priority;
} else if (priority == cur_priority) {
std::string err_msg =
"Key already registered with the same priority: " + KeyStrRepr(key);
fprintf(stderr, "%s\n", err_msg.c_str());
if (terminate_) {
std::exit(1);
} else {
throw std::runtime_error(err_msg);
}
} else if (warning_) {
std::string warn_msg =
"Higher priority item already registered, skipping registration of " +
KeyStrRepr(key);
fprintf(stderr, "%s\n", warn_msg.c_str());
}
} else {
registry_[key] = creator;
priority_[key] = priority;
}
}
void Register(
const SrcType& key,
Creator creator,
const std::string& help_msg,
const RegistryPriority priority = REGISTRY_DEFAULT) {
Register(key, creator, priority);
help_message_[key] = help_msg;
}
inline bool Has(const SrcType& key) {
return (registry_.count(key) != 0);
}
ObjectPtrType Create(const SrcType& key, Args... args) {
if (registry_.count(key) == 0) {
// Returns nullptr if the key is not registered.
return nullptr;
}
return registry_[key](args...);
}
/**
* Returns the keys currently registered as a std::vector.
*/
std::vector<SrcType> Keys() const {
std::vector<SrcType> keys;
for (const auto& it : registry_) {
keys.push_back(it.first);
}
return keys;
}
inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
return help_message_;
}
const char* HelpMessage(const SrcType& key) const {
auto it = help_message_.find(key);
if (it == help_message_.end()) {
return nullptr;
}
return it->second.c_str();
}
// Used for testing, if terminate is unset, Registry throws instead of
// calling std::exit
void SetTerminate(bool terminate) {
terminate_ = terminate;
}
private:
std::unordered_map<SrcType, Creator> registry_;
std::unordered_map<SrcType, RegistryPriority> priority_;
bool terminate_;
const bool warning_;
std::unordered_map<SrcType, std::string> help_message_;
std::mutex register_mutex_;
C10_DISABLE_COPY_AND_ASSIGN(Registry);
};
template <class SrcType, class ObjectPtrType, class... Args>
class Registerer {
public:
explicit Registerer(
const SrcType& key,
Registry<SrcType, ObjectPtrType, Args...>* registry,
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
const std::string& help_msg = "") {
registry->Register(key, creator, help_msg);
}
explicit Registerer(
const SrcType& key,
const RegistryPriority priority,
Registry<SrcType, ObjectPtrType, Args...>* registry,
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
const std::string& help_msg = "") {
registry->Register(key, creator, help_msg, priority);
}
template <class DerivedType>
static ObjectPtrType DefaultCreator(Args... args) {
return ObjectPtrType(new DerivedType(args...));
}
};
/**
* C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
* declaration, as well as creating a convenient typename for its corresponding
* registerer.
*/
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
// as import and DEFINE as export, because these registry macros will be used
// in downstream shared libraries as well, and one cannot use *_API - the API
// macro will be defined on a per-shared-library basis. Semantically, when one
// declares a typed registry it is always going to be IMPORT, and when one
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
// the instantiation unit is always going to be exported.
//
// The only unique condition is when in the same file one does DECLARE and
// DEFINE - in Windows compilers, this generates a warning that dllimport and
// dllexport are mixed, but the warning is fine and linker will be properly
// exporting the symbol. Same thing happens in the gflags flag declaration and
// definition cases.
// Expands to two declarations:
//   1. a C10_IMPORT-marked accessor function `RegistryName()` returning a
//      pointer to ::c10::Registry<SrcType, PtrType<ObjectType>, Args...>,
//      where Args... are the optional trailing macro arguments;
//   2. a typedef named Registerer<RegistryName> (token-pasted) for the
//      matching ::c10::Registerer instantiation.
// The expansion has no trailing semicolon; the invocation site supplies it.
#define C10_DECLARE_TYPED_REGISTRY( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_IMPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName(); \
typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
Registerer##RegistryName
// Defines the C10_EXPORT-marked accessor function `RegistryName()`.
// The Registry is allocated with `new` the first time the function runs and
// kept in a function-local static pointer; nothing in this macro deletes it.
// The Registry is default-constructed, i.e. with warnings enabled.
#define C10_DEFINE_TYPED_REGISTRY( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName() { \
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
registry = new ::c10:: \
Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
return registry; \
}
// Same as C10_DEFINE_TYPED_REGISTRY, except that the Registry is constructed
// with `false`, which disables the stderr message printed when a registration
// is skipped in favour of a higher-priority entry.
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName() { \
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
registry = \
new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
false); \
return registry; \
}
// Each macro in this group defines a static Registerer<RegistryName> object;
// the registration happens in that object's constructor. The variable name
// comes from C10_ANONYMOUS_VARIABLE (defined outside this file; presumably it
// yields a unique identifier - confirm in c10/macros/Macros.h). Every
// expansion already ends with a semicolon.
//
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with comma in its templated arguments.
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, RegistryName(), ##__VA_ARGS__);
// Same as C10_REGISTER_TYPED_CREATOR, with an explicit RegistryPriority
// passed ahead of the registry pointer.
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
RegistryName, key, priority, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, priority, RegistryName(), ##__VA_ARGS__);
// Registers the class named by __VA_ARGS__ under `key`, using the
// Registerer's DefaultCreator for that class as the creator and the result of
// ::c10::demangle_type for that class as the help message (demangle_type is
// declared outside this file; presumably it returns the readable type name).
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \
RegistryName(), \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
::c10::demangle_type<__VA_ARGS__>());
// Same as C10_REGISTER_TYPED_CLASS, with an explicit RegistryPriority.
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
RegistryName, key, priority, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \
priority, \
RegistryName(), \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
::c10::demangle_type<__VA_ARGS__>());
// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
// std::string as the key type, because that is the most commonly used case.
// These three forward to the typed macros with std::unique_ptr as the
// pointer wrapper.
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
C10_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
// The *_SHARED_* variants are identical except that the created objects are
// wrapped in std::shared_ptr instead of std::unique_ptr.
#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
C10_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use
// std::string as the key type, because that is the most commonly used case.
// The `key` argument is stringified with `#key`, so it is written as a bare
// token at the invocation site rather than as a quoted string literal.
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
RegistryName, #key, priority, __VA_ARGS__)
#define C10_REGISTER_CLASS(RegistryName, key, ...) \
C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
RegistryName, #key, priority, __VA_ARGS__)
} // namespace c10
#endif // C10_UTIL_REGISTRY_H_