doc for symbol attributes and naming convention (#2070)

pluskid committed May 15, 2016
1 parent 7841253 commit aa512fa
Showing 9 changed files with 108 additions and 38 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
2 changes: 1 addition & 1 deletion docs/get_started/index.md
@@ -259,7 +259,7 @@ mlp = @mx.chain mx.Variable(:data) =>
mx.FullyConnected(num_hidden=64) =>
mx.Activation(act_type=:relu) =>
mx.FullyConnected(num_hidden=10) =>
- mx.SoftmaxOutput()
+ mx.SoftmaxOutput(name=:softmax)
```

The model can be trained by
2 changes: 1 addition & 1 deletion docs/get_started/overview_zh.md
@@ -46,7 +46,7 @@ mlp = @mx.chain mx.Variable(:data) =>
mx.FullyConnected(num_hidden=64) =>
mx.Activation(act_type=:relu) =>
mx.FullyConnected(num_hidden=10) =>
- mx.Softmax()
+ mx.Softmax(name=:softmax)
```

Before executing a symbolic expression, we need to assign values to all free variables. In the example above, we need to provide the data, as well as the inputs implicitly defined inside each layer, such as the weights and biases of the fully connected layers. We also need to declare the desired outputs, such as the softmax output.
54 changes: 54 additions & 0 deletions docs/packages/python/symbol.md
@@ -1,5 +1,6 @@
# MXNet Python Symbolic API
* [How to Compose Symbols](#overloaded-operators) introduces operator overloading of symbols
* [Symbol Attributes](#symbol-attributes) introduces how to attach attributes to symbols
* [Serialization](#serialization) introduces how to save and load symbols.
* [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs
* [Symbol Creation API Reference](#symbol-creation-api-reference) gives reference to all functions.
@@ -38,6 +39,59 @@ The following code gives an example of a computation graph that adds two inputs together
>>> c = a + b
```

Symbol Attributes
-----------------
Attributes can be attached to symbols by providing an attribute dictionary when creating a symbol.
```python
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
num_filter=1, attr={'mood': 'so so'})
```
Both keys and values of the attribute dictionary should be strings, in order to communicate properly with the C++ backend. The attributes can be retrieved via `attr(key)` or `list_attr()`:
```python
assert data.attr('mood') == 'angry'
assert op.list_attr() == {'mood': 'so so'}
```
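For a single key, `attr(key)` returns `None` when the attribute is absent (see the `attr` implementation in `python/mxnet/symbol.py` further down in this commit), so lookups can be guarded without exception handling. A minimal sketch (the `tone` key is hypothetical):
```python
mood = data.attr('mood')   # 'angry'
tone = data.attr('tone')   # None: a missing key does not raise
tone = tone if tone is not None else 'neutral'
```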
In the case of a composite symbol, you can also retrieve all the attributes associated with that symbol *and its descendants* via `list_attr(recursive=True)`. Note that in the returned dictionary, all attribute names carry a prefix `'symbol_name' + '_'` in order to avoid naming conflicts.
```python
assert op.list_attr(recursive=True) == {'data_mood': 'angry', 'conv_mood': 'so so',
'conv_weight_mood': 'so so', 'conv_bias_mood': 'so so'}
```
Here you may have noticed that the `mood` attribute we set on the ```Convolution``` operator is copied to `conv_weight` and `conv_bias`. Those symbols are created automatically by the ```Convolution``` operator, and the attributes are copied for them automatically as well. This is intentional, and is especially useful for annotating context groups in model parallelism. However, if the weight or bias symbol is created explicitly by the user, the attributes of the host operator will *not* be copied to it:
```python
weight = mx.sym.Variable('crazy_weight', attr={'size': '5'})
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, weight=weight, name='conv', kernel=(1, 1),
num_filter=1, attr={'mood': 'so so'})
op.list_attr(recursive=True)
# =>
# {'conv_mood': 'so so',
# 'conv_bias_mood': 'so so',
# 'crazy_weight_size': '5',
# 'data_mood': 'angry'}
```
As you can see, the `mood` attribute is copied to the automatically created symbol `conv_bias`, but not to the manually created weight symbol `crazy_weight`.

Another way of attaching attributes is to use ```AttrScope```. An ```AttrScope``` will automatically add the specified attributes to all the symbols created within that scope. For example:
```python
data = mx.symbol.Variable('data')
with mx.AttrScope(group='4', data='great'):
fc1 = mx.symbol.Activation(data, act_type='relu')
with mx.AttrScope(init_bias='0.0'):
fc2 = mx.symbol.FullyConnected(fc1, num_hidden=10, name='fc2')
assert fc1.attr('data') == 'great'
assert fc2.attr('data') == 'great'
assert fc2.attr('init_bias') == '0.0'
```

**Naming convention**: it is recommended to choose attribute names that are valid variable names. Names with a double underscore (e.g. `__shape__`) are reserved for internal use. The underscore `'_'` is the character used to separate a symbol name from its attribute name, so it should *not* be used in attribute names.
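This convention is what lets consumers recover an attribute name from the prefixed keys of `list_attr(recursive=True)` by suffix matching, as the optimizer changes later in this commit do. A minimal sketch with made-up data:
```python
# Split a prefixed key back into symbol name and attribute name by
# stripping a known suffix; an extra '_' inside the attribute name
# would make this split ambiguous.
attrs = {'conv_weight_lr_mult': '0.1', 'data_mood': 'angry'}
lr_mult = {}
for k, v in attrs.items():
    if k.endswith('_lr_mult'):
        lr_mult[k[:-len('_lr_mult')]] = float(v)
assert lr_mult == {'conv_weight': 0.1}
```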

**Components that use attributes**: more and more components are using symbol attributes to collect useful annotations for the computational graph. Here is a (probably incomplete) list:

- ```Variable``` uses attributes to store (optional) shape information for a variable.
- Optimizers read the `lr_mult` and `wd_mult` attributes of each symbol in a computational graph. This is useful for controlling per-layer learning rate and weight decay (see the sketch after this list).
- The model parallelism LSTM example uses the `ctx_group` attribute to divide the operators into different groups corresponding to different GPU devices.
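As an illustration of the optimizer case, here is a minimal sketch (with hypothetical layer names) of attaching a per-layer multiplier; note that the attribute value must be a string:
```python
import mxnet as mx

# Attach a learning-rate multiplier to one weight as a string attribute.
w = mx.sym.Variable('fc1_weight', attr={'lr_mult': '0.1'})
fc = mx.sym.FullyConnected(data=mx.sym.Variable('data'), weight=w,
                           name='fc1', num_hidden=10)

# The optimizer later finds it under the prefixed key 'fc1_weight_lr_mult'.
assert fc.list_attr(recursive=True)['fc1_weight_lr_mult'] == '0.1'
```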

Serialization
-------------
There are two ways to save and load symbols. You can use pickle to serialize ```Symbol``` objects.
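A minimal sketch of the pickle route (the JSON-based `save`/`load` pair is truncated out of this hunk, so it is shown only as an assumed API):
```python
import pickle
import mxnet as mx

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)

# Route 1: pickle the Symbol like any other Python object.
net2 = pickle.loads(pickle.dumps(net))

# Route 2 (assumed API, described below this hunk): JSON save/load.
net.save('net.json')
net3 = mx.sym.load('net.json')
```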
4 changes: 2 additions & 2 deletions example/image-classification/train_cifar10_resnet.py
@@ -224,7 +224,7 @@ def set_wd_mult(self, args_wd_mult):
) or 'proj' in n or 'zscore' in n:
self.wd_mult[n] = 0.0
if self.sym is not None:
- attr = self.sym.list_attr()
+ attr = self.sym.list_attr(recursive=True)
for k, v in attr.items():
if k.endswith('_wd_mult'):
self.wd_mult[k[:-len('_wd_mult')]] = float(v)
@@ -245,7 +245,7 @@ def set_lr_mult(self, args_lr_mult):
if 'proj' in n or 'zscore' in n:
self.lr_mult[n] = 0.0
if self.sym is not None:
- attr = self.sym.list_attr()
+ attr = self.sym.list_attr(recursive=True)
for k, v in attr.items():
if k.endswith('_lr_mult'):
self.lr_mult[k[:-len('_lr_mult')]] = float(v)
34 changes: 22 additions & 12 deletions python/mxnet/optimizer.py
@@ -8,6 +8,7 @@
from .ndarray import NDArray, zeros, clip, sqrt
from .random import normal


class Optimizer(object):
"""Base class of all optimizers."""
opt_registry = {}
@@ -18,8 +19,8 @@ def register(klass):
assert(isinstance(klass, type))
name = klass.__name__.lower()
if name in Optimizer.opt_registry:
- print('WARNING: New optimizer %s.%s is overriding ' \
- 'existing optimizer %s.%s'%(
+ print('WARNING: New optimizer %s.%s is overriding '
+ 'existing optimizer %s.%s' % (
klass.__module__, klass.__name__,
Optimizer.opt_registry[name].__module__,
Optimizer.opt_registry[name].__name__))
@@ -137,7 +138,7 @@ def set_lr_mult(self, args_lr_mult):
"""
self.lr_mult = {}
if self.sym is not None:
- attr = self.sym.list_attr()
+ attr = self.sym.list_attr(recursive=True)
for k, v in attr.items():
if k.endswith('_lr_mult'):
self.lr_mult[k[:-len('_lr_mult')]] = float(v)
@@ -160,7 +161,7 @@ def set_wd_mult(self, args_wd_mult):
if not (n.endswith('_weight') or n.endswith('_gamma')):
self.wd_mult[n] = 0.0
if self.sym is not None:
- attr = self.sym.list_attr()
+ attr = self.sym.list_attr(recursive=True)
for k, v in attr.items():
if k.endswith('_wd_mult'):
self.wd_mult[k[:-len('_wd_mult')]] = float(v)
@@ -224,9 +225,10 @@ def _get_wd(self, index):
wd *= self.wd_mult.get(self.idx2name[index], 1.0)
return wd

- #convenience wrapper for Optimizer.Register
+ # convenience wrapper for Optimizer.Register
register = Optimizer.register


@register
class SGD(Optimizer):
"""A very simple SGD optimizer with momentum and weight regularization.
@@ -353,6 +355,7 @@ def update(self, index, weight, grad, state):
assert self.momentum == 0.0
weight[:] += -lr * (grad + self.wd * weight)


@register
class SGLD(Optimizer):
"""Stochastic Langevin Dynamics Updater to sample from a distribution.
@@ -414,8 +417,8 @@ def update(self, index, weight, grad, state):
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
- weight[:] += - lr/2 * (grad + wd * weight) \
- + normal(0, math.sqrt(lr), weight.shape, weight.context)
+ weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
+ weight.shape, weight.context)


@register
@@ -496,6 +499,7 @@ def update(self, index, weight, grad, state):
mx_float(lr),
mx_float(wd)))


@register
class Adam(Optimizer):
"""Adam optimizer as described in [King2014]_.
@@ -593,6 +597,7 @@ def update(self, index, weight, grad, state):
if wd > 0.:
weight[:] -= (lr * wd) * weight


@register
class AdaGrad(Optimizer):
"""AdaGrad optimizer of Duchi et al., 2011,
@@ -625,7 +630,7 @@ def __init__(self, eps=1e-7, **kwargs):
self.float_stable_eps = eps

def create_state(self, index, weight):
- return zeros(weight.shape, weight.context) # history
+ return zeros(weight.shape, weight.context)  # history

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
@@ -640,6 +645,7 @@ def update(self, index, weight, grad, state):
history[:] += (grad * grad)
weight[:] += -lr * (grad / sqrt(history + self.float_stable_eps) + self.wd * weight)


@register
class RMSProp(Optimizer):
"""RMSProp optimizer of Tieleman & Hinton, 2012,
@@ -713,6 +719,7 @@ def update(self, index, weight, grad, state):
delta[:] = (self.gamma2) * delta - lr * (grad/sqrt(n - g*g + 1e-4) + wd * weight)
weight[:] += delta


@register
class AdaDelta(Optimizer):
"""
@@ -741,8 +748,8 @@ def __init__(self, rho=0.90, epsilon=1e-5, **kwargs):
self.epsilon = epsilon

def create_state(self, index, weight):
- return (zeros(weight.shape, weight.context), # accumulated g
- zeros(weight.shape, weight.context)) # accumulated delta
+ return (zeros(weight.shape, weight.context),  # accumulated g
+ zeros(weight.shape, weight.context))  # accumulated delta

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
Expand All @@ -760,12 +767,13 @@ def update(self, index, weight, grad, state):

# update g, delta
acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad
- current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
+ current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
acc_delta[:] = self.rho * acc_delta + (1. - self.rho) * current_delta * current_delta

# update weight
weight[:] -= current_delta + wd * weight


@register
class Test(Optimizer):
"""For test use"""
@@ -782,9 +790,10 @@ def update(self, index, weight, grad, state):
weight[:] += grad * self.rescale_grad
state[:] = weight

- #backward compatibility wrapper for Optimizer.CreateOptimizer
+ # backward compatibility wrapper for Optimizer.CreateOptimizer
create = Optimizer.create_optimizer


def get_updater(optimizer):
"""Return a clossure of the updater needed for kvstore
@@ -799,6 +808,7 @@ def get_updater(optimizer):
The clossure of the updater
"""
states = dict()

def updater(index, grad, weight):
"""updater for kvstore"""
if index not in states:
36 changes: 19 additions & 17 deletions python/mxnet/symbol.py
@@ -247,21 +247,21 @@ def attr(self, key):
else:
return None

- def list_attr(self, shallow=False):
- """Get all attributes from the symbol and its descendents.
+ def list_attr(self, recursive=False):
+ """Get all attributes from the symbol.
Parameters
----------
- shallow : bool
- Default `False`. When `shallow` is `False`, list recursively all the
+ recursive : bool
+ Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
- the symbol names to avoid conflicts. If `True`, then only attributes
+ the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = mx_uint()
pairs = ctypes.POINTER(ctypes.c_char_p)()
- f_handle = _LIB.MXSymbolListAttrShallow if shallow else _LIB.MXSymbolListAttr
+ f_handle = _LIB.MXSymbolListAttr if recursive else _LIB.MXSymbolListAttrShallow
check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs)))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}
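
Note the rename also flips the default: plain `list_attr()` now returns only this symbol's own attributes, while `recursive=True` walks the descendants and prefixes each key with the owning symbol's name, which is why the call sites above were updated. A minimal usage sketch, reusing the example from the docs in this commit:
```python
import mxnet as mx

data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
                        num_filter=1, attr={'mood': 'so so'})

op.list_attr()                # {'mood': 'so so'} -- this symbol only
op.list_attr(recursive=True)  # adds descendants, e.g. 'data_mood',
                              # with symbol-name prefixes on every key
```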

@@ -536,7 +536,6 @@ def debug_str(self):
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)


def save(self, fname):
"""Save symbol into file.
@@ -990,12 +989,13 @@ def _make_atomic_symbol_function(handle):
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
- 'symbol: Symbol\n'+
+ 'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str)
extra_doc = "\n" + '\n'.join([x.__doc__ for x in type.__subclasses__(SymbolDoc)
if x.__name__ == '%sDoc' % func_name])
doc_str += re.sub(re.compile(" "), "", extra_doc)

def creator(*args, **kwargs):
"""Activation Operator of Neural Net.
The parameters listed below can be passed in as keyword arguments.
@@ -1077,6 +1077,7 @@ def _init_symbol_module():
# Initialize the atomic symbo in startups
_init_symbol_module()


# pylint: disable=no-member
# pylint: disable=redefined-builtin
def pow(base, exp):
Expand All @@ -1093,11 +1094,11 @@ def pow(base, exp):
"""
if isinstance(base, Symbol) and isinstance(exp, Symbol):
return Symbol._Power(base, exp)
- if isinstance(base, Symbol) and isinstance(exp, Number):
+ if isinstance(base, Symbol) and isinstance(exp, Number):
return Symbol._PowerScalar(base, scalar=exp)
- if isinstance(base, Number) and isinstance(exp, Symbol):
+ if isinstance(base, Number) and isinstance(exp, Symbol):
return Symbol._RPowerScalar(exp, scalar=base)
- if isinstance(base, Number) and isinstance(exp, Number):
+ if isinstance(base, Number) and isinstance(exp, Number):
return base**exp
else:
raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp))))
@@ -1119,15 +1120,16 @@ def maximum(left, right):
"""
if isinstance(left, Symbol) and isinstance(right, Symbol):
return Symbol._Maximum(left, right)
- if isinstance(left, Symbol) and isinstance(right, Number):
+ if isinstance(left, Symbol) and isinstance(right, Number):
return Symbol._MaximumScalar(left, scalar=right)
- if isinstance(left, Number) and isinstance(right, Symbol):
+ if isinstance(left, Number) and isinstance(right, Symbol):
return Symbol._MaximumScalar(right, scalar=left)
- if isinstance(left, Number) and isinstance(right, Number):
+ if isinstance(left, Number) and isinstance(right, Number):
return left if left > right else right
else:
raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))


# pylint: disable=no-member
# pylint: disable=redefined-builtin
def minimum(left, right):
@@ -1144,11 +1146,11 @@ def minimum(left, right):
"""
if isinstance(left, Symbol) and isinstance(right, Symbol):
return Symbol._Minimum(left, right)
- if isinstance(left, Symbol) and isinstance(right, Number):
+ if isinstance(left, Symbol) and isinstance(right, Number):
return Symbol._MinimumScalar(left, scalar=right)
- if isinstance(left, Number) and isinstance(right, Symbol):
+ if isinstance(left, Number) and isinstance(right, Symbol):
return Symbol._MinimumScalar(right, scalar=left)
- if isinstance(left, Number) and isinstance(right, Number):
+ if isinstance(left, Number) and isinstance(right, Number):
return left if left < right else right
else:
raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))
3 changes: 2 additions & 1 deletion src/symbol/symbol.cc
@@ -15,6 +15,7 @@ namespace mxnet {

namespace symbol_constants {
const char *kShapeKey = "__shape__";
const char *kNamespaceSeparator = "_";
} // namespace symbol_constants

/*!
@@ -494,7 +495,7 @@ std::map<std::string, std::string> Symbol::ListAttr() {
this->DFSVisit([&ret](const std::shared_ptr<Node> &n) {
if (n->attr.get() == nullptr) return;
for (const auto &it : *(n->attr.get())) {
- ret[n->name+"_"+it.first] = it.second;
+ ret[n->name+symbol_constants::kNamespaceSeparator+it.first] = it.second;
}
});
return ret;