From aa512fafd36e737006e555a953b9a287a8ed400e Mon Sep 17 00:00:00 2001
From: Chiyuan Zhang
Date: Wed, 11 May 2016 14:16:37 -0400
Subject: [PATCH] doc for symbol attributes and naming convention (#2070)

---
 dmlc-core                          |  2 +-
 docs/get_started/index.md          |  2 +-
 docs/get_started/overview_zh.md    |  2 +-
 docs/packages/python/symbol.md     | 54 +++++++++++++++++++
 .../train_cifar10_resnet.py        |  4 +-
 python/mxnet/optimizer.py          | 34 +++++++-----
 python/mxnet/symbol.py             | 36 +++++++------
 src/symbol/symbol.cc               |  3 +-
 tests/python/unittest/test_attr.py |  9 ++--
 9 files changed, 108 insertions(+), 38 deletions(-)

diff --git a/dmlc-core b/dmlc-core
index 9fd3b484..98bd72f3 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 9fd3b48462a7a651e12a197679f71e043dcb25a2
+Subproject commit 98bd72f3c4eed7680d3dd16992bf2cf8b8f5eb9c
diff --git a/docs/get_started/index.md b/docs/get_started/index.md
index 75486518..0133d3fa 100644
--- a/docs/get_started/index.md
+++ b/docs/get_started/index.md
@@ -259,7 +259,7 @@ mlp = @mx.chain mx.Variable(:data) =>
   mx.FullyConnected(num_hidden=64) =>
   mx.Activation(act_type=:relu) =>
   mx.FullyConnected(num_hidden=10) =>
-  mx.SoftmaxOutput()
+  mx.SoftmaxOutput(name=:softmax)
 ```

 The model can be trained by
diff --git a/docs/get_started/overview_zh.md b/docs/get_started/overview_zh.md
index 52586ff8..f1a81242 100644
--- a/docs/get_started/overview_zh.md
+++ b/docs/get_started/overview_zh.md
@@ -46,7 +46,7 @@ mlp = @mx.chain mx.Variable(:data) =>
   mx.FullyConnected(num_hidden=64) =>
   mx.Activation(act_type=:relu) =>
   mx.FullyConnected(num_hidden=10) =>
-  mx.Softmax()
+  mx.Softmax(name=:softmax)
 ```

 Before executing a symbolic expression, we need to assign values to all of its free variables. In the example above, we need to supply the data as well as the inputs implicitly defined inside each layer, such as the weights and biases of the fully connected layers. We also need to declare the outputs we want, such as the softmax output.
diff --git a/docs/packages/python/symbol.md b/docs/packages/python/symbol.md
index 60ee800a..da05cd30 100644
--- a/docs/packages/python/symbol.md
+++ b/docs/packages/python/symbol.md
@@ -1,5 +1,6 @@
 # MXNet Python Symbolic API
 * [How to Compose Symbols](#overloaded-operators) introduces operator overloading of symbols
+* [Symbol Attributes](#symbol-attributes) introduces how to attach attributes to symbols
 * [Serialization](#serialization) introduces how to save and load symbols.
 * [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs
 * [Symbol Creation API Reference](#symbol-creationapi-reference) gives reference to all functions.
@@ -38,6 +39,59 @@ The following code gives an example of a computation graph that adds two inputs together
 >>> c = a + b
 ````

Symbol Attributes
-----------------
Attributes can be attached to symbols by providing an attribute dictionary when creating a symbol.
```python
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
                        num_filter=1, attr={'mood': 'so so'})
```
Both keys and values of the attribute dictionary should be strings, in order to communicate properly with the C++ backend. The attributes can be retrieved via `attr(key)` or `list_attr()`:
```
assert data.attr('mood') == 'angry'
assert op.list_attr() == {'mood': 'so so'}
```
In the case of a composite symbol, you can also retrieve all the attributes associated with that symbol *and its descendants* via `list_attr(recursive=True)`. Note that in the returned dictionary, all attribute names carry the prefix `'symbol_name' + '_'` in order to avoid naming conflicts.
```python
assert op.list_attr(recursive=True) == {'data_mood': 'angry', 'conv_mood': 'so so',
                                        'conv_weight_mood': 'so so', 'conv_bias_mood': 'so so'}
```
Here you may have noticed that the `mood` attribute we set on the ```Convolution``` operator was copied to `conv_weight` and `conv_bias`. Those symbols are automatically created by the ```Convolution``` operator, and the attributes are automatically copied to them as well. This is intentional, and is especially useful for annotating context groups in model parallelism. However, if the weight or bias symbol is explicitly created by the user, the attributes of the host operator will *not* be copied to it:
```python
weight = mx.sym.Variable('crazy_weight', attr={'size': '5'})
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, weight=weight, name='conv', kernel=(1, 1),
                        num_filter=1, attr={'mood': 'so so'})
op.list_attr(recursive=True)
# =>
# {'conv_mood': 'so so',
#  'conv_bias_mood': 'so so',
#  'crazy_weight_size': '5',
#  'data_mood': 'angry'}
```
As you can see, the `mood` attribute is copied to the automatically created symbol `conv_bias`, but not to the manually created weight symbol `crazy_weight`.

Another way of attaching attributes is to use ```AttrScope```. An ```AttrScope``` automatically adds the specified attributes to all symbols created within that scope. For example:
```python
data = mx.symbol.Variable('data')
with mx.AttrScope(group='4', data='great'):
    fc1 = mx.symbol.Activation(data, act_type='relu')
    with mx.AttrScope(init_bias='0.0'):
        fc2 = mx.symbol.FullyConnected(fc1, num_hidden=10, name='fc2')
assert fc1.attr('data') == 'great'
assert fc2.attr('data') == 'great'
assert fc2.attr('init_bias') == '0.0'
```

**Naming convention**: it is recommended to choose attribute names that are valid variable names. Names with a double underscore (e.g. `__shape__`) are reserved for internal use. The underscore `'_'` is the character used to separate a symbol name from its attribute names, so it should *not* be used in attribute names.

**Components that use attributes**: more and more components use symbol attributes to collect useful annotations for the computational graph. Here is a (probably incomplete) list:

- ```Variable``` uses attributes to store (optional) shape information for a variable.
- Optimizers read the `lr_mult` and `wd_mult` attributes of each symbol in a computational graph. This is useful for controlling the per-layer learning rate and weight decay.
- The model parallelism LSTM example uses the `ctx_group` attribute to divide the operators into different groups corresponding to different GPU devices, as shown in the sketch below.
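
For instance, here is a minimal sketch of annotating context groups with ```AttrScope``` (the group names and the device mapping are illustrative, not taken from that example):
```python
import mxnet as mx

with mx.AttrScope(ctx_group='embed'):
    data = mx.sym.Variable('data')
with mx.AttrScope(ctx_group='decode'):
    net = mx.sym.FullyConnected(data=data, num_hidden=10, name='fc')
assert net.attr('ctx_group') == 'decode'
# At bind time, a group2ctx map can then place each group on its own device, e.g.:
# net.simple_bind(mx.cpu(), group2ctx={'embed': mx.gpu(0), 'decode': mx.gpu(1)},
#                 data=(64, 128))
```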

Serialization
-------------
There are two ways to save and load the symbols. You can use pickle to serialize the ```Symbol``` objects.
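As a minimal round-trip sketch (mirroring what `tests/python/unittest/test_attr.py` below does):
```python
import pickle as pkl
import mxnet as mx

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=10)
# pickle preserves the whole symbolic graph, including any attached attributes
net_copy = pkl.loads(pkl.dumps(net))
assert net_copy.tojson() == net.tojson()
```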
diff --git a/example/image-classification/train_cifar10_resnet.py b/example/image-classification/train_cifar10_resnet.py
index 58fe105c..3404d19a 100644
--- a/example/image-classification/train_cifar10_resnet.py
+++ b/example/image-classification/train_cifar10_resnet.py
@@ -224,7 +224,7 @@ def set_wd_mult(self, args_wd_mult):
                 ) or 'proj' in n or 'zscore' in n:
                 self.wd_mult[n] = 0.0
         if self.sym is not None:
-            attr = self.sym.list_attr()
+            attr = self.sym.list_attr(recursive=True)
             for k, v in attr.items():
                 if k.endswith('_wd_mult'):
                     self.wd_mult[k[:-len('_wd_mult')]] = float(v)
@@ -245,7 +245,7 @@ def set_lr_mult(self, args_lr_mult):
             if 'proj' in n or 'zscore' in n:
                 self.lr_mult[n] = 0.0
         if self.sym is not None:
-            attr = self.sym.list_attr()
+            attr = self.sym.list_attr(recursive=True)
             for k, v in attr.items():
                 if k.endswith('_lr_mult'):
                     self.lr_mult[k[:-len('_lr_mult')]] = float(v)
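Both hunks above switch to `list_attr(recursive=True)` so that multipliers attached to descendant symbols (such as a layer's weight) are picked up. A sketch of what this scan sees, using a hypothetical layer name:
```python
import mxnet as mx

w = mx.sym.Variable('conv1_weight', attr={'wd_mult': '0.0'})
net = mx.sym.Convolution(data=mx.sym.Variable('data'), weight=w,
                         kernel=(3, 3), num_filter=8, name='conv1')
attr = net.list_attr(recursive=True)
# recursive keys are prefixed with the owning symbol's name
assert attr['conv1_weight_wd_mult'] == '0.0'
# stripping the '_wd_mult' suffix recovers the argument name
wd_mult = {k[:-len('_wd_mult')]: float(v)
           for k, v in attr.items() if k.endswith('_wd_mult')}
assert wd_mult == {'conv1_weight': 0.0}
```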
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 01263f64..94f96eff 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -8,6 +8,7 @@
 from .ndarray import NDArray, zeros, clip, sqrt
 from .random import normal

+
 class Optimizer(object):
     """Base class of all optimizers."""
     opt_registry = {}
@@ -18,8 +19,8 @@ def register(klass):
         assert(isinstance(klass, type))
         name = klass.__name__.lower()
         if name in Optimizer.opt_registry:
-            print('WARNING: New optimizer %s.%s is overriding ' \
-                  'existing optimizer %s.%s'%(
+            print('WARNING: New optimizer %s.%s is overriding '
+                  'existing optimizer %s.%s' % (
                 klass.__module__, klass.__name__,
                 Optimizer.opt_registry[name].__module__,
                 Optimizer.opt_registry[name].__name__))
@@ -137,7 +138,7 @@ def set_lr_mult(self, args_lr_mult):
         """
         self.lr_mult = {}
         if self.sym is not None:
-            attr = self.sym.list_attr()
+            attr = self.sym.list_attr(recursive=True)
             for k, v in attr.items():
                 if k.endswith('_lr_mult'):
                     self.lr_mult[k[:-len('_lr_mult')]] = float(v)
@@ -160,7 +161,7 @@ def set_wd_mult(self, args_wd_mult):
             if not (n.endswith('_weight') or n.endswith('_gamma')):
                 self.wd_mult[n] = 0.0
         if self.sym is not None:
-            attr = self.sym.list_attr()
+            attr = self.sym.list_attr(recursive=True)
             for k, v in attr.items():
                 if k.endswith('_wd_mult'):
                     self.wd_mult[k[:-len('_wd_mult')]] = float(v)
@@ -224,9 +225,10 @@ def _get_wd(self, index):
             wd *= self.wd_mult.get(self.idx2name[index], 1.0)
         return wd

-#convenience wrapper for Optimizer.register
+# convenience wrapper for Optimizer.register
 register = Optimizer.register

+
 @register
 class SGD(Optimizer):
     """A very simple SGD optimizer with momentum and weight regularization.
@@ -353,6 +355,7 @@ def update(self, index, weight, grad, state):
             assert self.momentum == 0.0
             weight[:] += -lr * (grad + self.wd * weight)

+
 @register
 class SGLD(Optimizer):
     """Stochastic Langevin Dynamics Updater to sample from a distribution.
@@ -414,8 +417,8 @@ def update(self, index, weight, grad, state):
         grad = grad * self.rescale_grad
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
-        weight[:] += - lr/2 * (grad + wd * weight) \
-                     + normal(0, math.sqrt(lr), weight.shape, weight.context)
+        weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
+                                                            weight.shape, weight.context)


 @register
@@ -496,6 +499,7 @@ def update(self, index, weight, grad, state):
             mx_float(lr),
             mx_float(wd)))

+
 @register
 class Adam(Optimizer):
     """Adam optimizer as described in [King2014]_.
@@ -593,6 +597,7 @@ def update(self, index, weight, grad, state):
         if wd > 0.:
             weight[:] -= (lr * wd) * weight

+
 @register
 class AdaGrad(Optimizer):
     """AdaGrad optimizer of Duchi et al., 2011,
@@ -625,7 +630,7 @@ def __init__(self, eps=1e-7, **kwargs):
         self.float_stable_eps = eps

     def create_state(self, index, weight):
-        return zeros(weight.shape, weight.context) # history
+        return zeros(weight.shape, weight.context)  # history

     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
@@ -640,6 +645,7 @@ def update(self, index, weight, grad, state):
         history[:] += (grad * grad)
         weight[:] += -lr * (grad / sqrt(history + self.float_stable_eps) + self.wd * weight)

+
 @register
 class RMSProp(Optimizer):
     """RMSProp optimizer of Tieleman & Hinton, 2012,
@@ -713,6 +719,7 @@ def update(self, index, weight, grad, state):
         delta[:] = (self.gamma2) * delta - lr * (grad/sqrt(n - g*g + 1e-4) + wd * weight)
         weight[:] += delta

+
 @register
 class AdaDelta(Optimizer):
     """
@@ -741,8 +748,8 @@ def __init__(self, rho=0.90, epsilon=1e-5, **kwargs):
         self.epsilon = epsilon

     def create_state(self, index, weight):
-        return (zeros(weight.shape, weight.context), # accumulated g
-                zeros(weight.shape, weight.context)) # accumulated delta
+        return (zeros(weight.shape, weight.context),  # accumulated g
+                zeros(weight.shape, weight.context))  # accumulated delta

     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
@@ -760,12 +767,13 @@ def update(self, index, weight, grad, state):
         # update g, delta
         acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad
-        current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
+        current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
         acc_delta[:] = self.rho * acc_delta + (1. - self.rho) * current_delta * current_delta

         # update weight
         weight[:] -= current_delta + wd * weight

+
 @register
 class Test(Optimizer):
     """For test use"""
@@ -782,9 +790,10 @@ def update(self, index, weight, grad, state):
         weight[:] += grad * self.rescale_grad
         state[:] = weight

-#backward compatibility wrapper for Optimizer.create_optimizer
+# backward compatibility wrapper for Optimizer.create_optimizer
 create = Optimizer.create_optimizer

+
 def get_updater(optimizer):
     """Return a closure of the updater needed for kvstore.

     Parameters
     ----------
     optimizer: Optimizer
         The optimizer

     Returns
     -------
     updater: function
         The closure of the updater
     """
     states = dict()
+
     def updater(index, grad, weight):
         """updater for kvstore"""
         if index not in states:
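As a usage sketch of the `get_updater` closure defined above (the parameter name and values are illustrative):
```python
import mxnet as mx

opt = mx.optimizer.create('sgd', learning_rate=0.1,
                          param_idx2name={0: 'weight'})
updater = mx.optimizer.get_updater(opt)

weight = mx.nd.ones((2, 2))
grad = mx.nd.ones((2, 2))
# one in-place update step; the integer index identifies the parameter so that
# per-parameter state (e.g. momentum) can be kept across calls
updater(0, grad, weight)
```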
""" size = mx_uint() pairs = ctypes.POINTER(ctypes.c_char_p)() - f_handle = _LIB.MXSymbolListAttrShallow if shallow else _LIB.MXSymbolListAttr + f_handle = _LIB.MXSymbolListAttr if recursive else _LIB.MXSymbolListAttrShallow check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs))) return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} @@ -536,7 +536,6 @@ def debug_str(self): self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) - def save(self, fname): """Save symbol into file. @@ -990,12 +989,13 @@ def _make_atomic_symbol_function(handle): ' Name of the resulting symbol.\n\n' + 'Returns\n' + '-------\n' + - 'symbol: Symbol\n'+ + 'symbol: Symbol\n' + ' The result symbol.') doc_str = doc_str % (desc, param_str) extra_doc = "\n" + '\n'.join([x.__doc__ for x in type.__subclasses__(SymbolDoc) if x.__name__ == '%sDoc' % func_name]) doc_str += re.sub(re.compile(" "), "", extra_doc) + def creator(*args, **kwargs): """Activation Operator of Neural Net. The parameters listed below can be passed in as keyword arguments. @@ -1077,6 +1077,7 @@ def _init_symbol_module(): # Initialize the atomic symbo in startups _init_symbol_module() + # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): @@ -1093,11 +1094,11 @@ def pow(base, exp): """ if isinstance(base, Symbol) and isinstance(exp, Symbol): return Symbol._Power(base, exp) - if isinstance(base, Symbol) and isinstance(exp, Number): + if isinstance(base, Symbol) and isinstance(exp, Number): return Symbol._PowerScalar(base, scalar=exp) - if isinstance(base, Number) and isinstance(exp, Symbol): + if isinstance(base, Number) and isinstance(exp, Symbol): return Symbol._RPowerScalar(exp, scalar=base) - if isinstance(base, Number) and isinstance(exp, Number): + if isinstance(base, Number) and isinstance(exp, Number): return base**exp else: raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp)))) @@ -1119,15 +1120,16 @@ def maximum(left, right): """ if isinstance(left, Symbol) and isinstance(right, Symbol): return Symbol._Maximum(left, right) - if isinstance(left, Symbol) and isinstance(right, Number): + if isinstance(left, Symbol) and isinstance(right, Number): return Symbol._MaximumScalar(left, scalar=right) - if isinstance(left, Number) and isinstance(right, Symbol): + if isinstance(left, Number) and isinstance(right, Symbol): return Symbol._MaximumScalar(right, scalar=left) - if isinstance(left, Number) and isinstance(right, Number): + if isinstance(left, Number) and isinstance(right, Number): return left if left > right else right else: raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right)))) + # pylint: disable=no-member # pylint: disable=redefined-builtin def minimum(left, right): @@ -1144,11 +1146,11 @@ def minimum(left, right): """ if isinstance(left, Symbol) and isinstance(right, Symbol): return Symbol._Minimum(left, right) - if isinstance(left, Symbol) and isinstance(right, Number): + if isinstance(left, Symbol) and isinstance(right, Number): return Symbol._MinimumScalar(left, scalar=right) - if isinstance(left, Number) and isinstance(right, Symbol): + if isinstance(left, Number) and isinstance(right, Symbol): return Symbol._MinimumScalar(right, scalar=left) - if isinstance(left, Number) and isinstance(right, Number): + if isinstance(left, Number) and isinstance(right, Number): return left if left > right else right else: raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right)))) diff 
diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index e9bb5b30..dc896287 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -15,6 +15,7 @@
 namespace mxnet {
 namespace symbol_constants {
 const char *kShapeKey = "__shape__";
+const char *kNamespaceSeparator = "_";
 }  // namespace symbol_constants
 /*!
@@ -494,7 +495,7 @@ std::map<std::string, std::string> Symbol::ListAttr() {
   this->DFSVisit([&ret](const std::shared_ptr<Node> &n) {
       if (n->attr.get() == nullptr) return;
       for (const auto &it : *(n->attr.get())) {
-        ret[n->name+"_"+it.first] = it.second;
+        ret[n->name+symbol_constants::kNamespaceSeparator+it.first] = it.second;
       }
     });
   return ret;
diff --git a/tests/python/unittest/test_attr.py b/tests/python/unittest/test_attr.py
index 7af78576..f13a8321 100644
--- a/tests/python/unittest/test_attr.py
+++ b/tests/python/unittest/test_attr.py
@@ -21,17 +21,20 @@ def test_operator():
         with mx.AttrScope(init_bias='0.0'):
             fc2 = mx.symbol.FullyConnected(fc1, num_hidden=10, name='fc2')
     assert fc1.attr('data') == 'great'
+    assert fc2.attr('data') == 'great'
+    assert fc2.attr('init_bias') == '0.0'
     fc2copy = pkl.loads(pkl.dumps(fc2))
     assert fc2copy.tojson() == fc2.tojson()
     fc2weight = fc2.get_internals()['fc2_weight']

+
 def test_list_attr():
     data = mx.sym.Variable('data', attr={'mood': 'angry'})
     op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
                             num_filter=1, attr={'mood': 'so so'})
-    assert op.list_attr() == {'data_mood': 'angry', 'conv_mood': 'so so',
-                              'conv_weight_mood': 'so so', 'conv_bias_mood': 'so so'}
-    assert op.list_attr(shallow=True) == {'mood': 'so so'}
+    assert op.list_attr(recursive=True) == {'data_mood': 'angry', 'conv_mood': 'so so',
+                                            'conv_weight_mood': 'so so', 'conv_bias_mood': 'so so'}
+    assert op.list_attr() == {'mood': 'so so'}

 if __name__ == '__main__':