fix hidden keys (#4414)
* fix hidden keys

* fix
piiswrong committed Dec 29, 2016
1 parent 528dee0 commit 04dc056
Showing 7 changed files with 111 additions and 27 deletions.
4 changes: 2 additions & 2 deletions docs/api/python/symbol.md
@@ -99,8 +99,8 @@
 **Components that use attributes**: More and more components are using symbol attributes to collect useful annotations for the computational graph. Here is a (probably incomplete) list:

 - ```Variable``` uses attributes to store (optional) shape information for a variable.
-- Optimizers read `lr_mult` and `wd_mult` attributes for each symbol in a computational graph. This is useful to control per-layer learning rate and decay.
-- The model parallelism LSTM example uses the `ctx_group` attribute to divide the operators into groups that correspond to GPU devices.
+- Optimizers read `__lr_mult__` and `__wd_mult__` attributes for each symbol in a computational graph. This is useful to control per-layer learning rate and decay.
+- The model parallelism LSTM example uses the `__ctx_group__` attribute to divide the operators into groups that correspond to GPU devices.

 ## Serialization

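A minimal sketch of what the renamed keys look like from Python (the fallback read of the bare key is added by this commit; values are stored as strings via `str()`):

```python
import mxnet as mx

# lr_mult/wd_mult passed to Variable are stored under the hidden
# '__lr_mult__'/'__wd_mult__' keys; reading the bare key falls back
# to the hidden spelling after this commit.
w = mx.sym.Variable('w', lr_mult=2.0, wd_mult=0.5)
print(w.attr('__lr_mult__'))  # '2.0'
print(w.attr('lr_mult'))      # '2.0' (fallback added in this commit)
print(w.attr('__wd_mult__'))  # '0.5'
```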
2 changes: 1 addition & 1 deletion example/model-parallel-lstm/lstm.py
@@ -183,7 +183,7 @@ def setup_rnn_model(default_ctx,
    arg_arrays = []
    args_grad = {}
    for shape, name in zip(arg_shape, arg_names):
-        group = internals[name].attr("ctx_group")
+        group = internals[name].attr("__ctx_group__")
        ctx = group2ctx[group] if group is not None else default_ctx
        arg_arrays.append(mx.nd.zeros(shape, ctx))
        if is_param_name(name):
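For context, a sketch of how `ctx_group` annotations are bound to devices, in the style of the model-parallel example; the group names, shapes, and contexts below are illustrative assumptions:

```python
import mxnet as mx

# Operators created inside an AttrScope inherit its ctx_group
# attribute (stored under the hidden '__ctx_group__' key).
with mx.AttrScope(ctx_group='embed'):
    data = mx.sym.Variable('data')
with mx.AttrScope(ctx_group='decode'):
    net = mx.sym.FullyConnected(data=data, num_hidden=128, name='fc')

# Each group name is mapped to a device when binding the executor.
group2ctx = {'embed': mx.cpu(0), 'decode': mx.cpu(1)}
exe = net.simple_bind(ctx=mx.cpu(), group2ctx=group2ctx, data=(32, 64))
```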
2 changes: 1 addition & 1 deletion nnvm (submodule pointer update)
10 changes: 9 additions & 1 deletion python/mxnet/symbol.py
@@ -955,7 +955,7 @@ def grad(self, wrt):
 # pylint: enable= no-member


-def Variable(name, attr=None, shape=None, lr_mult=None, wd_mult=None):
+def Variable(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None):
     """Create a symbolic variable with specified name.

     Parameters
@@ -968,6 +968,12 @@ def Variable(name, attr=None, shape=None, lr_mult=None, wd_mult=None):
         Optionally, one can specify the shape of a variable. This will be used during
         shape inference. If user specified a different shape for this variable using
         keyword argument when calling shape inference, this shape information will be ignored.
+    lr_mult : float
+        Specify learning rate multiplier for this variable.
+    wd_mult : float
+        Specify weight decay multiplier for this variable.
+    dtype : str or numpy.dtype
+        Similar to shape, we can specify dtype for this variable.

     Returns
     -------
@@ -987,6 +993,8 @@ def Variable(name, attr=None, shape=None, lr_mult=None, wd_mult=None):
         attr['__lr_mult__'] = str(lr_mult)
     if wd_mult is not None:
         attr['__wd_mult__'] = str(wd_mult)
+    if dtype is not None:
+        attr['__dtype__'] = str(_DTYPE_NP_TO_MX[_numpy.dtype(dtype).type])
     ret._set_attr(**attr)
     return ret

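A quick sketch of the new `dtype` argument in action, assuming the updated build (the printed type id comes from `_DTYPE_NP_TO_MX`, which maps `float32` to `0`):

```python
import numpy as np
import mxnet as mx

# dtype is stored as the hidden '__dtype__' attribute, encoded as
# MXNet's integer type id.
x = mx.sym.Variable('x', shape=(2, 3), dtype=np.float32, lr_mult=2.0)
print(x.attr('__dtype__'))    # '0'
print(x.attr('__lr_mult__'))  # '2.0'
```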
100 changes: 85 additions & 15 deletions src/c_api/c_api_symbolic.cc
@@ -3,7 +3,6 @@
  * \file c_api_symbolic.cc
  * \brief C API of mxnet
  */
-
 #include <mxnet/base.h>
 #include <mxnet/c_api.h>
 #include <nnvm/c_api.h>
@@ -19,8 +18,13 @@ void RegisterLegacyOpProp();
 void RegisterLegacyNDFunc();
 }
 const std::vector<std::string> kHiddenKeys = {
-  "ctx_group", "lr_mult", "wd_mult", "__force_mirroring__"
+  "ctx_group", "lr_mult", "wd_mult", "force_mirroring"
 };
+const std::vector<std::string> kReplacedHiddenKeys = {
+  "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__"
+};
+const char *kNamespaceSeparator = "$";

 DMLC_JSON_ENABLE_ANY(int, int);
@@ -163,42 +167,108 @@ int MXSymbolGetAttr(SymbolHandle symbol,
                     const char* key,
                     const char** out,
                     int* success) {
-  return NNSymbolGetAttr(symbol, key, out, success);
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  API_BEGIN();
+  if (s->GetAttr(key, &(ret->ret_str))) {
+    *out = (ret->ret_str).c_str();
+    *success = 1;
+  } else {
+    *out = nullptr;
+    *success = 0;
+    if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) {
+      std::string skey = "__" + std::string(key) + "__";
+      if (s->GetAttr(skey, &(ret->ret_str))) {
+        *out = (ret->ret_str).c_str();
+        *success = 1;
+      }
+    }
+  }
+  API_END();
 }

 int MXSymbolSetAttr(SymbolHandle symbol,
                     const char* key,
                     const char* value) {
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
+  API_BEGIN();
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  std::string skey(key), sval(value);
   for (const auto &k : kHiddenKeys) {
-    std::string tmp(key);
-    size_t pos = tmp.rfind(k);
-    if (pos == 0) {
-      tmp = "__" + tmp + "__";
-      const char *tkey = tmp.c_str();
-      return NNSymbolSetAttrs(symbol, 1, &tkey, &value);
-    } else if (pos != std::string::npos && pos == tmp.length() - k.length()) {
+    size_t pos = skey.rfind(k);
+    if (pos == 0 && k.length() == skey.length()) {
+      skey = "__" + skey + "__";
+      break;
+    } else if (pos != std::string::npos && pos + k.length() == skey.length()) {
       std::ostringstream os;
       os << "setting variable attributes with " << key << " is deprecated. "
          << "please instead use\nw = Variable(" << k << "=" << value << ")\n"
-         << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)";
+         << "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)";
       throw dmlc::Error(os.str());
     }
   }
-  return NNSymbolSetAttrs(symbol, 1, &key, &value);
+  kwargs.emplace_back(std::make_pair(std::move(skey), std::move(sval)));
+  s->SetAttrs(kwargs);
+  API_END();
 }

 int MXSymbolListAttr(SymbolHandle symbol,
                      mx_uint *out_size,
                      const char*** out) {
-  return NNSymbolListAttrs(symbol, 0, out_size, out);
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  API_BEGIN();
+  std::vector<std::tuple<std::string, std::string, std::string> > attr =
+      s->ListAttrsRecursive();
+
+  std::vector<std::string>& attr_list = ret->ret_vec_str;
+  attr_list.clear();
+  for (const auto& tp : attr) {
+    attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp));
+    attr_list.emplace_back(std::get<2>(tp));
+    if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp))
+          != kReplacedHiddenKeys.end()) {
+      attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator +
+                          std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4));
+      attr_list.push_back(std::get<2>(tp));
+    }
+  }
+  *out_size = attr_list.size()/2;
+  ret->ret_vec_charp.clear();
+  for (size_t i = 0; i < attr_list.size(); ++i) {
+    ret->ret_vec_charp.push_back(attr_list[i].c_str());
+  }
+  *out = dmlc::BeginPtr(ret->ret_vec_charp);
+  API_END();
 }

 int MXSymbolListAttrShallow(SymbolHandle symbol,
                             mx_uint *out_size,
                             const char*** out) {
-  return NNSymbolListAttrs(symbol, 1, out_size, out);
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  API_BEGIN();
+  std::unordered_map<std::string, std::string> attr =
+      s->ListAttrs(static_cast<nnvm::Symbol::ListAttrOption>(1));  // NOLINT(*)
+
+  std::vector<std::string>& attr_list = ret->ret_vec_str;
+  attr_list.clear();
+  for (const auto& kv : attr) {
+    attr_list.push_back(kv.first);
+    attr_list.push_back(kv.second);
+    if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first)
+          != kReplacedHiddenKeys.end()) {
+      attr_list.push_back(kv.first.substr(2, kv.first.length() - 4));
+      attr_list.push_back(kv.second);
+    }
+  }
+  *out_size = attr_list.size()/2;
+  ret->ret_vec_charp.clear();
+  for (size_t i = 0; i < attr_list.size(); ++i) {
+    ret->ret_vec_charp.push_back(attr_list[i].c_str());
+  }
+  *out = dmlc::BeginPtr(ret->ret_vec_charp);
+  API_END();
 }

 int MXSymbolListOutputs(SymbolHandle symbol,

@@ -444,7 +514,7 @@ int MXSymbolInferType(SymbolHandle sym,
     mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
   }

-  g = nnvm::pass::InferType(std::move(g), arg_types);
+  g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__");
   // copy back
   CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
            &(ret->arg_types), &(ret->out_types), &(ret->aux_types));
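Taken together, these C API changes make the bare and double-underscore spellings interchangeable when reading and listing attributes, while writes are canonicalized or rejected. A minimal Python-side sketch of the observable behavior, assuming the updated build (attribute values are stored as strings; `_set_attr` is the internal setter that `Variable` itself uses):

```python
import mxnet as mx

w = mx.sym.Variable('w', lr_mult=2.0)

# Reads: MXSymbolGetAttr falls back from the bare key to '__key__'.
assert w.attr('lr_mult') == w.attr('__lr_mult__') == '2.0'

# Listing: MXSymbolListAttrShallow reports hidden keys under both spellings.
attrs = w.list_attr()
assert attrs['lr_mult'] == attrs['__lr_mult__'] == '2.0'

# Writes: MXSymbolSetAttr canonicalizes an exact hidden key to its
# '__key__' form; a suffixed key such as 'weight_lr_mult' raises the
# deprecation error shown above.
w._set_attr(wd_mult='0.5')
assert w.attr('__wd_mult__') == '0.5'
```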
2 changes: 1 addition & 1 deletion src/executor/graph_executor.cc
@@ -398,7 +398,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
   arg_types.resize(idx.input_nodes().size(), -1);
   // other initializations
   g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
-  g = nnvm::pass::InferType(g, arg_types);
+  g = nnvm::pass::InferType(g, arg_types, "__dtype__");
   g = nnvm::ApplyPass(g, "PlanMemory");
   g = DetectInplaceAddTo(g);
   return g;
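The executor seeds type inference from `__dtype__` the same way the standalone `MXSymbolInferType` does, so a dtype fixed on a `Variable` now propagates without an explicit `type_dict`. A small sketch of the visible effect, assuming the updated build (the `FullyConnected` layer is just an illustration):

```python
import numpy as np
import mxnet as mx

# InferType is now seeded from '__dtype__', so the dtype fixed on 'x'
# propagates through type inference to the other arguments.
x = mx.sym.Variable('x', dtype=np.float16)
y = mx.sym.FullyConnected(data=x, num_hidden=4, name='fc')
arg_types, out_types, aux_types = y.infer_type()
print(dict(zip(y.list_arguments(), arg_types)))
# e.g. {'x': <class 'numpy.float16'>, 'fc_weight': <class 'numpy.float16'>, ...}
```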
18 changes: 12 additions & 6 deletions tests/python/unittest/test_attr.py
@@ -7,10 +7,16 @@ def test_attr_basic():
     with mx.AttrScope(group='4', data='great'):
         data = mx.symbol.Variable('data',
                                   attr={'dtype':'data',
-                                        'group': '1'})
+                                        'group': '1',
+                                        'force_mirroring': 'True'},
+                                  lr_mult=1)
     gdata = mx.symbol.Variable('data2')
     assert gdata.attr('group') == '4'
     assert data.attr('group') == '1'
+    assert data.attr('lr_mult') == '1'
+    assert data.attr('__lr_mult__') == '1'
+    assert data.attr('force_mirroring') == 'True'
+    assert data.attr('__force_mirroring__') == 'True'

     data2 = pkl.loads(pkl.dumps(data))
     assert data.attr('dtype') == data2.attr('dtype')

@@ -43,21 +49,21 @@ def contain(x, y):
 def test_list_attr():
     data = mx.sym.Variable('data', attr={'mood': 'angry'})
     op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
-                            num_filter=1, attr={'__mood__': 'so so'})
-    assert contain({'__mood__': 'so so'}, op.list_attr())
+                            num_filter=1, attr={'__mood__': 'so so', 'wd_mult': 'x'})
+    assert contain({'__mood__': 'so so', 'wd_mult': 'x', '__wd_mult__': 'x'}, op.list_attr())

 def test_attr_dict():
     data = mx.sym.Variable('data', attr={'mood': 'angry'})
     op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
-                            num_filter=1, attr={'__mood__': 'so so'})
+                            num_filter=1, attr={'__mood__': 'so so'}, lr_mult=1)
     assert contain({
         'data': {'mood': 'angry'},
         'conv_weight': {'__mood__': 'so so'},
-        'conv': {'kernel': '(1, 1)', '__mood__': 'so so', 'num_filter': '1'},
+        'conv': {'kernel': '(1, 1)', '__mood__': 'so so', 'num_filter': '1', 'lr_mult': '1', '__lr_mult__': '1'},
         'conv_bias': {'__mood__': 'so so'}}, op.attr_dict())

 if __name__ == '__main__':
     test_attr_basic()
     test_operator()
     test_list_attr()
-    test_attr_dict()
\ No newline at end of file
+    test_attr_dict()
