Skip to content

Commit

Permalink
Doc for activation (#3538)
Browse files Browse the repository at this point in the history
* rnn-cell demo (push to server for testing)

* a running example with cuDNN RNN cell

* ndarray concatenate

* fix lint errors

* allow batch_axis in executor_group

* add batch_axis parameter for all modules

* fix bug in copy slice implementation

* fix module examples

* use batch_axis if data iterator provided such information

* rnn cell example in time major

* fix init state names in rnn cell bucketing example

* sanity check stochastic depth mnist

* a cifar10 example (not tested)

* add description for sd cifar

* add doc for sd module

* add a simple random number queue

* add final numbers

* fix typo

* default layout mapper

* fix other modules for layout mapper

* fix typo

* softmax output mode that preserves the shape

* comments on run-time speed of time-major

* extend layout mapper to include other information

* fix data layout API change

* fix lint errors

* fix Travis CI numpy error on unit test

* add infrastructure for symbol doctest

* add regression test demo for FullyConnected

* move utils to test_utils.py

* fix lint error

* more doc for Activation op

* doc for Flatten
  • Loading branch information
pluskid committed Oct 17, 2016
1 parent 2e9d9e6 commit 6505983
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 12 deletions.
66 changes: 64 additions & 2 deletions python/mxnet/symbol_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- all the operators (e.g. `FullyConnected`)
- the name `test_utils` for `mxnet.test_utils` (e.g. `test_utils.reldiff`)
- the name `mxnet` (e.g. `mxnet.nd.zeros`)
- the name `numpy`
The following documents are recommended:
Expand All @@ -40,6 +41,69 @@ def get_output_shape(sym, **input_shapes):
return dict(zip(sym.list_outputs(), s_outputs))


class ActivationDoc(SymbolDoc):
"""
Examples
--------
A one-hidden-layer MLP with ReLU activation:
>>> data = Variable('data')
>>> mlp = FullyConnected(data=data, num_hidden=128, name='proj')
>>> mlp = Activation(data=mlp, act_type='relu', name='activation')
>>> mlp = FullyConnected(data=mlp, num_hidden=10, name='mlp')
>>> mlp
<Symbol mlp>
Regression Test
---------------
ReLU activation
>>> test_suites = [
... ('relu', lambda x: numpy.maximum(x, 0)),
... ('sigmoid', lambda x: 1 / (1 + numpy.exp(-x))),
... ('tanh', lambda x: numpy.tanh(x)),
... ('softrelu', lambda x: numpy.log(1 + numpy.exp(x)))
... ]
>>> x = test_utils.random_arrays((2, 3, 4))
>>> for act_type, numpy_impl in test_suites:
... op = Activation(act_type=act_type, name='act')
... y = test_utils.simple_forward(op, act_data=x)
... y_np = numpy_impl(x)
... print('%s: %s' % (act_type, test_utils.almost_equal(y, y_np)))
relu: True
sigmoid: True
tanh: True
softrelu: True
"""


class FlattenDoc(SymbolDoc):
"""
Examples
--------
Flatten is usually applied before `FullyConnected`, to reshape the 4D tensor
produced by convolutional layers to 2D matrix:
>>> data = Variable('data') # say this is 4D from some conv/pool
>>> flatten = Flatten(data=data, name='flat') # now this is 2D
>>> SymbolDoc.get_output_shape(flatten, data=(2, 3, 4, 5))
{'flat_output': (2L, 60L)}
Regression Test
---------------
>>> test_dims = [(2, 3, 4, 5), (2, 3), (2,)]
>>> op = Flatten(name='flat')
>>> for dims in test_dims:
... x = test_utils.random_arrays(dims)
... y = test_utils.simple_forward(op, flat_data=x)
... y_np = x.reshape((dims[0], numpy.prod(dims[1:])))
... print('%s: %s' % (dims, test_utils.almost_equal(y, y_np)))
(2, 3, 4, 5): True
(2, 3): True
(2,): True
"""


class FullyConnectedDoc(SymbolDoc):
"""
Examples
Expand Down Expand Up @@ -77,7 +141,6 @@ class FullyConnectedDoc(SymbolDoc):
>>> test_utils.almost_equal(out, out_np)
True
"""
pass


class ConcatDoc(SymbolDoc):
Expand All @@ -97,7 +160,6 @@ class ConcatDoc(SymbolDoc):
Note the shape should be the same except on the dimension that is being
concatenated.
"""
pass


class BroadcastPlusDoc(SymbolDoc):
Expand Down
7 changes: 5 additions & 2 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ def default_numerical_threshold():

def random_arrays(*shapes):
"""Generate some random numpy arrays."""
return [np.random.randn(*s).astype(default_dtype())
for s in shapes]
arrays = [np.random.randn(*s).astype(default_dtype())
for s in shapes]
if len(arrays) == 1:
return arrays[0]
return arrays


def np_reduce(dat, axis, keepdims, numpy_reduce_func):
Expand Down
16 changes: 12 additions & 4 deletions src/operator/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,18 @@ Operator *ActivationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_
DMLC_REGISTER_PARAMETER(ActivationParam);

MXNET_REGISTER_OP_PROPERTY(Activation, ActivationProp)
.describe("Apply activation function to input. "
"Softmax Activation is only available with CUDNN on GPU "
"and will be computed at each location across channel if input is 4D.")
.add_argument("data", "Symbol", "Input data to activation function.")
.describe(R"(Elementwise activation function.
The following activation types are supported (operations are applied elementwisely to each
scalar of the input tensor):
- `relu`: Rectified Linear Unit, `y = max(x, 0)`
- `sigmoid`: `y = 1 / (1 + exp(-x))`
- `tanh`: Hyperbolic tangent, `y = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
- `softrelu`: Soft ReLU, or SoftPlus, `y = log(1 + exp(x))`
See `LeakyReLU` for other activations with parameters.
)")
.add_arguments(ActivationParam::__FIELDS__());

} // namespace op
Expand Down
3 changes: 2 additions & 1 deletion src/operator/reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ MXNET_REGISTER_OP_PROPERTY(Reshape, ReshapeProp)
.add_arguments(ReshapeParam::__FIELDS__());

MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp)
.describe("Flatten input")
.describe(R"(Flatten input into 2D by collapsing all the higher dimensions.
A (d1, d2, ..., dK) tensor is flatten to (d1, d2* ... *dK) matrix.)")
.add_argument("data", "Symbol", "Input data to flatten.");
} // namespace op
} // namespace mxnet
7 changes: 4 additions & 3 deletions tests/python/doctest/run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import doctest
import logging
import mxnet
import numpy

def import_into(globs, module, names=None, error_on_overwrite=True):
"""Import names from module into the globs dict.
Parameters
----------
"""
Expand All @@ -16,7 +17,7 @@ def import_into(globs, module, names=None, error_on_overwrite=True):
mod_names = names

for name in mod_names:
if name in globs:
if name in globs and globs[name] is not getattr(module, name):
error_msg = 'Attempting to overwrite definition of %s' % name
if error_on_overwrite:
raise RuntimeError(error_msg)
Expand All @@ -27,7 +28,7 @@ def import_into(globs, module, names=None, error_on_overwrite=True):


def test_symbols():
globs = {'mxnet': mxnet, 'test_utils': mxnet.test_utils}
globs = {'numpy': numpy, 'mxnet': mxnet, 'test_utils': mxnet.test_utils}

# make sure all the operators are available
import_into(globs, mxnet.symbol)
Expand Down

0 comments on commit 6505983

Please sign in to comment.