Skip to content

Commit 67179f7

Browse files
committed
Enable Alias, refactor C API to reflect Op Semantics (apache#41)
* Enable Alias, refactor C API to reflect Op Semantics * add alias example
1 parent 410062d commit 67179f7

File tree

14 files changed

+184
-95
lines changed

14 files changed

+184
-95
lines changed

nnvm/example/src/operator.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity)
127127
NNVM_REGISTER_OP(add)
128128
.describe("add two data together")
129129
.set_num_inputs(2)
130+
.add_alias("__add_symbol__")
130131
.attr<FInferShape>("FInferShape", SameShape)
131132
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
132133
.attr<FGradient>(

nnvm/include/dmlc/registry.h

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,19 @@ namespace dmlc {
2626
template<typename EntryType>
2727
class Registry {
2828
public:
29-
/*! \return list of functions in the registry */
30-
inline static const std::vector<const EntryType*> &List() {
31-
return Get()->entry_list_;
29+
/*! \return list of entries in the registry(excluding alias) */
30+
inline static const std::vector<const EntryType*>& List() {
31+
return Get()->const_list_;
32+
}
33+
/*! \return list all names registered in the registry, including alias */
34+
inline static std::vector<std::string> ListAllNames() {
35+
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
36+
typename std::map<std::string, EntryType*>::const_iterator p;
37+
std::vector<std::string> names;
38+
for (p = fmap.begin(); p !=fmap.end(); ++p) {
39+
names.push_back(p->first);
40+
}
41+
return names;
3242
}
3343
/*!
3444
* \brief Find the entry with corresponding name.
@@ -44,6 +54,21 @@ class Registry {
4454
return NULL;
4555
}
4656
}
57+
/*!
58+
* \brief Add alias to the key_name
59+
* \param key_name The original entry key
60+
* \param alias The alias key.
61+
*/
62+
inline void AddAlias(const std::string& key_name,
63+
const std::string& alias) {
64+
EntryType* e = fmap_.at(key_name);
65+
if (fmap_.count(alias)) {
66+
CHECK_EQ(e, fmap_.at(alias))
67+
<< "Entry " << e->name << " already registered under different entry";
68+
} else {
69+
fmap_[alias] = e;
70+
}
71+
}
4772
/*!
4873
* \brief Internal function to register a name function under name.
4974
* \param name name of the function
@@ -55,6 +80,7 @@ class Registry {
5580
EntryType *e = new EntryType();
5681
e->name = name;
5782
fmap_[name] = e;
83+
const_list_.push_back(e);
5884
entry_list_.push_back(e);
5985
return *e;
6086
}
@@ -79,16 +105,17 @@ class Registry {
79105

80106
private:
81107
/*! \brief list of entry types */
82-
std::vector<const EntryType*> entry_list_;
108+
std::vector<EntryType*> entry_list_;
109+
/*! \brief list of entry types */
110+
std::vector<const EntryType*> const_list_;
83111
/*! \brief map of name->function */
84112
std::map<std::string, EntryType*> fmap_;
85113
/*! \brief constructor */
86114
Registry() {}
87115
/*! \brief destructor */
88116
~Registry() {
89-
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin();
90-
p != fmap_.end(); ++p) {
91-
delete p->second;
117+
for (size_t i = 0; i < entry_list_.size(); ++i) {
118+
delete entry_list_[i];
92119
}
93120
}
94121
};

nnvm/include/nnvm/c_api.h

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
typedef unsigned int nn_uint;
3030

3131
/*! \brief handle to a function that takes param and creates symbol */
32-
typedef void *AtomicSymbolCreator;
32+
typedef void *OpHandle;
3333
/*! \brief handle to a symbol that can be bind as operator */
3434
typedef void *SymbolHandle;
3535
/*! \brief handle to Graph */
@@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg);
5353
NNVM_DLL const char *NNGetLastError(void);
5454

5555
/*!
56-
* \brief list all the available AtomicSymbolEntry
56+
* \brief list all the available operator names, include entries.
57+
* \param out_size the size of returned array
58+
* \param out_array the output operator name array.
59+
* \return 0 when success, -1 when failure happens
60+
*/
61+
NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
62+
const char*** out_array);
63+
64+
/*!
65+
* \brief Get operator handle given name.
66+
* \param op_name The name of the operator.
67+
* \param op_out The returnning op handle.
68+
*/
69+
NNVM_DLL int NNGetOpHandle(const char* op_name,
70+
OpHandle* op_out);
71+
72+
/*!
73+
* \brief list all the available operators.
74+
* This won't include the alias, use ListAllNames
75+
* instead to get all alias names.
76+
*
5777
* \param out_size the size of returned array
5878
* \param out_array the output AtomicSymbolCreator array
5979
* \return 0 when success, -1 when failure happens
6080
*/
61-
NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
62-
AtomicSymbolCreator **out_array);
81+
NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
82+
OpHandle **out_array);
83+
6384
/*!
6485
* \brief Get the detailed information about atomic symbol.
65-
* \param creator the AtomicSymbolCreator.
66-
* \param name The returned name of the creator.
86+
* \param op The operator handle.
87+
* \param real_name The returned name of the creator.
88+
* This name is not the alias name of the atomic symbol.
6789
* \param description The returned description of the symbol.
6890
* \param num_doc_args Number of arguments that contain documents.
6991
* \param arg_names Name of the arguments of doc args
@@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
7294
* \param return_type Return type of the function, if any.
7395
* \return 0 when success, -1 when failure happens
7496
*/
75-
NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
76-
const char **name,
77-
const char **description,
78-
nn_uint *num_doc_args,
79-
const char ***arg_names,
80-
const char ***arg_type_infos,
81-
const char ***arg_descriptions,
82-
const char **return_type);
97+
NNVM_DLL int NNGetOpInfo(OpHandle op,
98+
const char **real_name,
99+
const char **description,
100+
nn_uint *num_doc_args,
101+
const char ***arg_names,
102+
const char ***arg_type_infos,
103+
const char ***arg_descriptions,
104+
const char **return_type);
83105
/*!
84106
* \brief Create an AtomicSymbol functor.
85-
* \param creator the AtomicSymbolCreator
107+
* \param op The operator handle
86108
* \param num_param the number of parameters
87109
* \param keys the keys to the params
88110
* \param vals the vals of the params
89111
* \param out pointer to the created symbol handle
90112
* \return 0 when success, -1 when failure happens
91113
*/
92-
NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
114+
NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
93115
nn_uint num_param,
94116
const char **keys,
95117
const char **vals,

nnvm/include/nnvm/op.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ class Op {
199199
* \return reference to self.
200200
*/
201201
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
202+
/*!
203+
* \brief Add another alias to this operator.
204+
* The same Op can be queried with Op::Get(alias)
205+
* \param alias The alias of the operator.
206+
* \return reference to self.
207+
*/
208+
Op& add_alias(const std::string& alias); // NOLINT(*)
202209
/*!
203210
* \brief Register additional attributes to operator.
204211
* \param attr_name The name of the attribute.

nnvm/python/nnvm/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _load_lib():
4545

4646
# type definitions
4747
nn_uint = ctypes.c_uint
48-
SymbolCreatorHandle = ctypes.c_void_p
48+
OpHandle = ctypes.c_void_p
4949
SymbolHandle = ctypes.c_void_p
5050
GraphHandle = ctypes.c_void_p
5151

File renamed without changes.
File renamed without changes.

nnvm/python/nnvm/ctypes/symbol.py renamed to nnvm/python/nnvm/_ctypes/symbol.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
from .._base import _LIB
1010
from .._base import c_array, c_str, nn_uint, py_str, string_types
11-
from .._base import SymbolHandle
11+
from .._base import SymbolHandle, OpHandle
1212
from .._base import check_call, ctypes2docstring
1313
from ..name import NameManager
1414
from ..attribute import AttrScope
@@ -114,25 +114,25 @@ def _set_symbol_class(cls):
114114
_symbol_cls = cls
115115

116116

117-
def _make_atomic_symbol_function(handle):
117+
def _make_atomic_symbol_function(handle, name):
118118
"""Create an atomic symbol function by handle and funciton name."""
119-
name = ctypes.c_char_p()
119+
real_name = ctypes.c_char_p()
120120
desc = ctypes.c_char_p()
121121
num_args = nn_uint()
122122
arg_names = ctypes.POINTER(ctypes.c_char_p)()
123123
arg_types = ctypes.POINTER(ctypes.c_char_p)()
124124
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
125125
ret_type = ctypes.c_char_p()
126126

127-
check_call(_LIB.NNSymbolGetAtomicSymbolInfo(
128-
handle, ctypes.byref(name), ctypes.byref(desc),
127+
check_call(_LIB.NNGetOpInfo(
128+
handle, ctypes.byref(real_name), ctypes.byref(desc),
129129
ctypes.byref(num_args),
130130
ctypes.byref(arg_names),
131131
ctypes.byref(arg_types),
132132
ctypes.byref(arg_descs),
133133
ctypes.byref(ret_type)))
134134
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
135-
func_name = py_str(name.value)
135+
func_name = name
136136
desc = py_str(desc.value)
137137

138138
doc_str = ('%s\n\n' +
@@ -199,22 +199,25 @@ def creator(*args, **kwargs):
199199
return creator
200200

201201

202-
def _init_symbol_module():
202+
def _init_symbol_module(symbol_class, root_namespace):
203203
"""List and add all the atomic symbol functions to current module."""
204-
plist = ctypes.POINTER(ctypes.c_void_p)()
204+
_set_symbol_class(symbol_class)
205+
plist = ctypes.POINTER(ctypes.c_char_p)()
205206
size = ctypes.c_uint()
206207

207-
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
208-
ctypes.byref(plist)))
209-
module_obj = sys.modules["nnvm.symbol"]
210-
module_internal = sys.modules["nnvm._symbol_internal"]
208+
check_call(_LIB.NNListAllOpNames(ctypes.byref(size),
209+
ctypes.byref(plist)))
210+
op_names = []
211211
for i in range(size.value):
212-
hdl = SymbolHandle(plist[i])
213-
function = _make_atomic_symbol_function(hdl)
212+
op_names.append(py_str(plist[i]))
213+
214+
module_obj = sys.modules["%s.symbol" % root_namespace]
215+
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
216+
for name in op_names:
217+
hdl = OpHandle()
218+
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
219+
function = _make_atomic_symbol_function(hdl, name)
214220
if function.__name__.startswith('_'):
215221
setattr(module_internal, function.__name__, function)
216222
else:
217223
setattr(module_obj, function.__name__, function)
218-
219-
# Initialize the atomic symbol in startups
220-
_init_symbol_module()

nnvm/python/nnvm/cython/base.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
ctypedef void* SymbolHandle
2-
ctypedef void* AtomicSymbolCreator
2+
ctypedef void* OpHandle
33
ctypedef unsigned nn_uint
44

55
cdef py_str(const char* x):

nnvm/python/nnvm/cython/symbol.pyd

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)