diff --git a/docs/api/python/foo.loss.md b/docs/api/python/foo.loss.md deleted file mode 100644 index b35a6942c6a1..000000000000 --- a/docs/api/python/foo.loss.md +++ /dev/null @@ -1,23 +0,0 @@ -# Foo Loss API - -```eval_rst -.. currentmodule:: mxnet.foo.loss -``` - -```eval_rst -.. warning:: This package is currently experimental and may change in the near future. -``` - -## API Reference - - - -```eval_rst -.. automethod:: mxnet.foo.loss.custom_loss -.. automethod:: mxnet.foo.loss.multitask_loss -.. automethod:: mxnet.foo.loss.l1_loss -.. automethod:: mxnet.foo.loss.l2_loss -.. automethod:: mxnet.foo.loss.softmax_cross_entropy_loss -``` - - diff --git a/docs/api/python/foo.md b/docs/api/python/foo.md index 72897fc1676b..0bff54ca9c36 100644 --- a/docs/api/python/foo.md +++ b/docs/api/python/foo.md @@ -1,4 +1,4 @@ -# Foo API +# Foo Package ```eval_rst .. currentmodule:: mxnet.foo @@ -8,7 +8,18 @@ .. warning:: This package is currently experimental and may change in the near future. ``` -## API Reference +## Overview + +Foo package is a high-level interface for MXNet designed to be easy to use while +keeping most of the flexibility of low level API. Foo supports both imperative +and symbolic programming, making it easy to train complex models imperatively +in Python and then deploy with symbolic graph in C++ and Scala. + +## Parameter + +```eval_rst +.. currentmodule:: mxnet.foo +``` @@ -17,8 +28,192 @@ :members: .. autoclass:: mxnet.foo.ParameterDict :members: +``` + + + + +## Neural Network Layers + +```eval_rst +.. currentmodule:: mxnet.foo.nn +``` + +### Containers + + + +```eval_rst +.. currentmodule:: mxnet.foo.nn +.. autoclass:: mxnet.foo.nn.Layer + :members: + + .. automethod:: __call__ +.. autoclass:: mxnet.foo.nn.Sequential + :members: +``` + + + +### Basic Layers + + + +```eval_rst +.. currentmodule:: mxnet.foo.nn +.. autoclass:: mxnet.foo.nn.Dense + :members: +.. autoclass:: mxnet.foo.nn.Activation + :members: +.. autoclass:: mxnet.foo.nn.Dropout + :members: +.. autoclass:: mxnet.foo.nn.BatchNorm + :members: +.. autoclass:: mxnet.foo.nn.LeakyReLU + :members: +.. autoclass:: mxnet.foo.nn.Embedding + :members: +``` + + + +### Convolutional Layers + + + +```eval_rst +.. currentmodule:: mxnet.foo.nn +.. autoclass:: mxnet.foo.nn.Conv1D + :members: +.. autoclass:: mxnet.foo.nn.Conv2D + :members: +.. autoclass:: mxnet.foo.nn.Conv3D + :members: +.. autoclass:: mxnet.foo.nn.Conv1DTranspose + :members: +.. autoclass:: mxnet.foo.nn.Conv2DTranspose + :members: +.. autoclass:: mxnet.foo.nn.Conv3DTranspose + :members: +``` + + + + +### Pooling Layers + + + +```eval_rst +.. currentmodule:: mxnet.foo.nn +.. autoclass:: mxnet.foo.nn.MaxPool1D + :members: +.. autoclass:: mxnet.foo.nn.MaxPool2D + :members: +.. autoclass:: mxnet.foo.nn.MaxPool3D + :members: +.. autoclass:: mxnet.foo.nn.AvgPool1D + :members: +.. autoclass:: mxnet.foo.nn.AvgPool2D + :members: +.. autoclass:: mxnet.foo.nn.AvgPool3D + :members: +.. autoclass:: mxnet.foo.nn.GlobalMaxPool1D + :members: +.. autoclass:: mxnet.foo.nn.GlobalMaxPool2D + :members: +.. autoclass:: mxnet.foo.nn.GlobalMaxPool3D + :members: +.. autoclass:: mxnet.foo.nn.GlobalAvgPool1D + :members: +.. autoclass:: mxnet.foo.nn.GlobalAvgPool2D + :members: +.. autoclass:: mxnet.foo.nn.GlobalAvgPool3D + :members: +``` + + + + +## Recurrent Layers + +```eval_rst +.. currentmodule:: mxnet.foo.rnn +``` + + + +```eval_rst +.. autoclass:: mxnet.foo.rnn.RecurrentCell + :members: + + .. automethod:: __call__ +.. autoclass:: mxnet.foo.rnn.LSTMCell + :members: +.. 
autoclass:: mxnet.foo.rnn.GRUCell + :members: +.. autoclass:: mxnet.foo.rnn.RNNCell + :members: +.. autoclass:: mxnet.foo.rnn.FusedRNNCell + :members: +.. autoclass:: mxnet.foo.rnn.SequentialRNNCell + :members: +.. autoclass:: mxnet.foo.rnn.BidirectionalCell + :members: +.. autoclass:: mxnet.foo.rnn.DropoutCell + :members: +.. autoclass:: mxnet.foo.rnn.ZoneoutCell + :members: +.. autoclass:: mxnet.foo.rnn.ResidualCell + :members: +``` + + + +## Trainer + +```eval_rst +.. currentmodule:: mxnet.foo +``` + + + +```eval_rst .. autoclass:: mxnet.foo.Trainer :members: ``` + +## Loss functions + +```eval_rst +.. currentmodule:: mxnet.foo.loss +``` + + + +```eval_rst +.. automethod:: mxnet.foo.loss.custom_loss +.. automethod:: mxnet.foo.loss.multitask_loss +.. automethod:: mxnet.foo.loss.l1_loss +.. automethod:: mxnet.foo.loss.l2_loss +.. automethod:: mxnet.foo.loss.softmax_cross_entropy_loss +``` + + + +## Utilities + +```eval_rst +.. currentmodule:: mxnet.foo.utils +``` + + + +```eval_rst +.. automethod:: mxnet.foo.utils.split_data +.. automethod:: mxnet.foo.utils.load_data +``` + + diff --git a/docs/api/python/foo.nn.md b/docs/api/python/foo.nn.md deleted file mode 100644 index 184f0ecc5dbb..000000000000 --- a/docs/api/python/foo.nn.md +++ /dev/null @@ -1,72 +0,0 @@ -# Foo NN API - -```eval_rst -.. currentmodule:: mxnet.foo.nn -``` - -```eval_rst -.. warning:: This package is currently experimental and may change in the near future. -``` - -## API Reference - - - -```eval_rst -.. currentmodule:: mxnet.foo.nn -.. autoclass:: mxnet.foo.nn.Layer - :members: - - .. automethod:: __call__ -.. autoclass:: mxnet.foo.nn.Sequential - :members: -.. autoclass:: mxnet.foo.nn.Dense - :members: -.. autoclass:: mxnet.foo.nn.Activation - :members: -.. autoclass:: mxnet.foo.nn.Dropout - :members: -.. autoclass:: mxnet.foo.nn.BatchNorm - :members: -.. autoclass:: mxnet.foo.nn.LeakyReLU - :members: - -.. autoclass:: mxnet.foo.nn.Conv1D - :members: -.. autoclass:: mxnet.foo.nn.Conv2D - :members: -.. autoclass:: mxnet.foo.nn.Conv3D - :members: -.. autoclass:: mxnet.foo.nn.Conv1DTranspose - :members: -.. autoclass:: mxnet.foo.nn.Conv2DTranspose - :members: -.. autoclass:: mxnet.foo.nn.Conv3DTranspose - :members: -.. autoclass:: mxnet.foo.nn.MaxPool1D - :members: -.. autoclass:: mxnet.foo.nn.MaxPool2D - :members: -.. autoclass:: mxnet.foo.nn.MaxPool3D - :members: -.. autoclass:: mxnet.foo.nn.AvgPool1D - :members: -.. autoclass:: mxnet.foo.nn.AvgPool2D - :members: -.. autoclass:: mxnet.foo.nn.AvgPool3D - :members: -.. autoclass:: mxnet.foo.nn.GlobalMaxPool1D - :members: -.. autoclass:: mxnet.foo.nn.GlobalMaxPool2D - :members: -.. autoclass:: mxnet.foo.nn.GlobalMaxPool3D - :members: -.. autoclass:: mxnet.foo.nn.GlobalAvgPool1D - :members: -.. autoclass:: mxnet.foo.nn.GlobalAvgPool2D - :members: -.. autoclass:: mxnet.foo.nn.GlobalAvgPool3D - :members: -``` - - diff --git a/docs/api/python/foo.rnn.md b/docs/api/python/foo.rnn.md deleted file mode 100644 index e2c2b37a1f6f..000000000000 --- a/docs/api/python/foo.rnn.md +++ /dev/null @@ -1,40 +0,0 @@ -# Foo RNN API - -```eval_rst -.. currentmodule:: mxnet.foo.rnn -``` - -```eval_rst -.. warning:: This package is currently experimental and may change in the near future. -``` - -## API Reference - - - -```eval_rst -.. autoclass:: mxnet.foo.rnn.RecurrentCell - :members: - - .. automethod:: __call__ -.. autoclass:: mxnet.foo.rnn.LSTMCell - :members: -.. autoclass:: mxnet.foo.rnn.GRUCell - :members: -.. autoclass:: mxnet.foo.rnn.RNNCell - :members: -.. 
autoclass:: mxnet.foo.rnn.FusedRNNCell - :members: -.. autoclass:: mxnet.foo.rnn.SequentialRNNCell - :members: -.. autoclass:: mxnet.foo.rnn.BidirectionalCell - :members: -.. autoclass:: mxnet.foo.rnn.DropoutCell - :members: -.. autoclass:: mxnet.foo.rnn.ZoneoutCell - :members: -.. autoclass:: mxnet.foo.rnn.ResidualCell - :members: -``` - - diff --git a/docs/api/python/foo.utils.md b/docs/api/python/foo.utils.md deleted file mode 100644 index 21dea1a0c2b9..000000000000 --- a/docs/api/python/foo.utils.md +++ /dev/null @@ -1,20 +0,0 @@ -# Foo Utility API - -```eval_rst -.. currentmodule:: mxnet.foo.utils -``` - -```eval_rst -.. warning:: This package is currently experimental and may change in the near future. -``` - -## API Reference - - - -```eval_rst -.. automethod:: mxnet.foo.utils.split_data -.. automethod:: mxnet.foo.utils.load_data -``` - - diff --git a/docs/api/python/index.md b/docs/api/python/index.md index fe102eb6a601..43f02677126a 100644 --- a/docs/api/python/index.md +++ b/docs/api/python/index.md @@ -29,10 +29,6 @@ imported by running: symbol module foo - foo.nn - foo.rnn - foo.loss - foo.utils rnn kvstore io diff --git a/docs/tutorials/basic/foo.md b/docs/tutorials/basic/foo.md new file mode 100644 index 000000000000..84b14278158c --- /dev/null +++ b/docs/tutorials/basic/foo.md @@ -0,0 +1,291 @@ +# Foo - High-level Interface + +Foo package is a high-level interface for MXNet designed to be easy to use while +keeping most of the flexibility of low level API. Foo supports both imperative +and symbolic programming, making it easy to train complex models imperatively +in Python and then deploy with symbolic graph in C++ and Scala. + +This tutorial covers four topics: +- MXNet NDArray as a replacement of numpy for asynchronous scientific computing +across CPU and GPU. +- Automatic differentiation with NDArray. +- Define and train neural network models with Foo's imperative API. +- [TODO] Save trained models as symbolic graph for easy production deployment. + +## Setup +First, let's import MXNet and Foo: + +```python +from __future__ import print_function +import numpy as np +import mxnet as mx +``` + +## NDArray + +### Creating NDArray + +NDArray is similar to numpy's ndarray, but supports asynchronous operations +and GPU. There are many ways to create NDArray. + +Construct from (nested) list: +```python +x = mx.nd.array([[1, 2, 3], [4, 5, 6]]) +print(x) +``` + +Construct from numpy array: +```python +x_numpy = np.ones((2, 3)) +x = mx.nd.array(x_numpy) +print(x) +``` + +Array construction routines: +```python +# create an 2x3 array of ones +x = mx.nd.ones((2, 3)) +print(x) +# create an 2x3 array of zeros +x = mx.nd.zeros((2, 3)) +print(x) +# create an 1d-array of 0 to 5 and reshape to 2x3 +x = mx.nd.arange(6).reshape((2, 3)) +print(x) +``` + +You can convert any NDArray to numpy array with `.asnumpy()`: +```python +z = x.asnumpy() +print(z) +``` + +### NDArray Operations + +NDArray supports a wide range of operations. Simple operations can be called +with python syntax: + +```python +x = mx.nd.array([[1, 2], [3, 4]]) +y = mx.nd.array([[4, 3], [2, 1]]) +print(x + y) +``` + +You can also call operators from the `mxnet.ndarray` (or `mx.nd` for short) name space: + +```python +z = mx.nd.add(x, y) +print(z) +``` + +You can also pass additional flags to operators: + +```python +z = mx.nd.sum(x, axis=0) +print('axis=0:', z) +z = mx.nd.sum(x, axis=1) +print('axis=1:', z) +``` + +By default operators create new NDArrays for return value. 
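+
+As a quick check (reusing `x` and `y` from the examples above), the result of an
+operator call is a brand-new NDArray and the inputs are left untouched:
+
+```python
+z = mx.nd.add(x, y)
+print(z is x, z is y)  # False False
+print(x)               # x is unchanged
+```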
+You can specify `out` to reuse a pre-allocated buffer instead:
+
+```python
+z = mx.nd.empty((2, 2))
+mx.nd.add(x, y, out=z)
+print(z)
+```
+
+### Using GPU
+
+Each NDArray lives on a `Context`. MXNet supports `mx.cpu()` for CPU and `mx.gpu(0)`,
+`mx.gpu(1)`, etc. for GPUs. You can specify the context when creating an NDArray:
+
+```python
+# Creates on CPU (the default).
+# Replace mx.cpu() with mx.gpu(0) if you have a GPU.
+x = mx.nd.zeros((2, 2), ctx=mx.cpu())
+print(x)
+x = mx.nd.array([[1, 2], [3, 4]], ctx=mx.cpu())
+print(x)
+```
+
+You can copy arrays between devices with `.copyto()`:
+
+```python
+# Copy x to CPU. Replace with mx.gpu(0) if you have a GPU.
+y = x.copyto(mx.cpu())
+# Copy x to another NDArray, possibly on another Context.
+x.copyto(y)
+print(y)
+```
+
+See the [NDArray tutorial](ndarray.md) for a more detailed introduction to the
+NDArray API.
+
+## Automatic Differentiation
+
+MXNet supports automatic differentiation with the `autograd` package.
+`autograd` allows you to differentiate a network of NDArray operations.
+This is called define-by-run, i.e., the network is defined on-the-fly by
+running forward computation. You can define exotic network structures
+and differentiate them, and each iteration can have a totally different
+network structure.
+
+```python
+from mxnet import autograd
+from mxnet.autograd import train_section
+```
+
+To use `autograd`, we must first mark the variables that require gradients by
+attaching gradient buffers to them with `set_grad()`:
+
+```python
+x = mx.nd.array([[1, 2], [3, 4]])
+x.set_grad()
+```
+
+Now we can define the network while running forward computation by wrapping
+it inside a `train_section` (operations outside of `train_section` do not define
+a graph and cannot be differentiated):
+
+```python
+with train_section():
+    y = x * 2
+    z = y * x
+```
+
+Let's backprop with `z.backward()`, which is equivalent to
+`z.backward(mx.nd.ones_like(z))`. When `z` has more than one entry, `z.backward()`
+is equivalent to `mx.nd.sum(z).backward()`:
+
+```python
+z.backward()
+print(x.grad)
+```
+
+## Neural Networks and Layers
+
+Neural networks (and other machine learning models) can be defined and trained
+with the `foo.nn` and `foo.rnn` packages. A typical training script has the following
+steps:
+
+- Define network
+- Initialize parameters
+- Loop over inputs
+- Forward input through network to get output
+- Compute loss with output and label
+- Backprop gradient
+- Update parameters with gradient descent.
+
+
+### Define Network
+
+`foo.nn.Layer` is the basic building block of models. You can define networks by
+composing and inheriting from `Layer`:
+
+```python
+import mxnet.foo as foo
+from mxnet.foo import nn
+
+class Net(nn.Layer):
+    def __init__(self, **kwargs):
+        super(Net, self).__init__(**kwargs)
+        with self.name_scope():
+            # Layers created in name_scope will inherit the name space
+            # from the parent layer.
+            self.conv1 = nn.Conv2D(6, kernel_size=5)
+            self.pool1 = nn.MaxPool2D(pool_size=2)
+            self.conv2 = nn.Conv2D(16, kernel_size=5)
+            self.pool2 = nn.MaxPool2D(pool_size=2)
+            self.fc1 = nn.Dense(120)
+            self.fc2 = nn.Dense(84)
+            self.fc3 = nn.Dense(10)
+
+    def forward(self, F, x):
+        x = self.pool1(F.relu(self.conv1(x)))
+        x = self.pool2(F.relu(self.conv2(x)))
+        # 0 means copy over the size from the corresponding input dimension.
+        # -1 means infer the size from the rest of the dimensions.
+        x = x.reshape((0, -1))
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+```
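+
+For simple feed-forward stacks you can also compose layers with `nn.Sequential`
+instead of subclassing `Layer`. The following is only an illustrative sketch of
+that pattern (the layer sizes here are arbitrary and unrelated to `Net` above):
+
+```python
+net2 = nn.Sequential()
+net2.add(nn.Dense(120, in_units=784, activation='relu'))
+net2.add(nn.Dense(84, activation='relu'))
+net2.add(nn.Dense(10))
+```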
+
+### Initialize Parameters
+
+A network must be created and initialized before it can be used:
+
+```python
+net = Net()
+# Initialize on CPU. Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`,
+# etc. to use one or more GPUs.
+net.all_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
+```
+
+Note that because we didn't specify the input sizes of the layers in Net's constructor,
+the shapes of the parameters cannot be determined at this point. Actual initialization
+is deferred to the first forward pass, i.e. accessing `net.fc1.weight.data()`
+now would raise an exception.
+
+You can trigger the actual initialization by running a forward pass:
+
+```python
+data = mx.nd.random_normal(shape=(10, 1, 32, 32))  # dummy data
+output = net(data)
+```
+
+Alternatively, you can specify input sizes when creating layers, e.g. `nn.Dense(84, in_units=120)`
+instead of `nn.Dense(84)`.
+
+### Loss Functions
+
+Loss functions take (output, label) pairs and compute a scalar loss for each sample
+in the mini-batch. The scalars measure how far each output is from the label.
+
+There are many predefined loss functions in `foo.loss`. Here we use
+`softmax_cross_entropy_loss` for digit classification.
+
+To compute the loss and backprop for one iteration, we do:
+
+```python
+label = mx.nd.arange(10)  # dummy label
+with train_section():
+    output = net(data)
+    loss = foo.loss.softmax_cross_entropy_loss(output, label)
+    loss.backward()
+print('loss:', loss)
+print('grad:', net.fc1.weight.grad())
+```
+
+### Updating the Weights
+
+Now that the gradient is computed, we just need to update the weights. This is usually
+done with formulas like `weight = weight - learning_rate * grad / batch_size`.
+Note that we divide the gradient by batch_size because the gradient is aggregated over the
+entire batch. For example,
+
+```python
+lr = 0.01
+for p in net.all_params().values():
+    p.data()[:] -= lr / data.shape[0] * p.grad()
+```
+
+But sometimes you want fancier update rules like momentum and Adam, and since
+this is a commonly used functionality, foo provides a `Trainer` class for it:
+
+```python
+trainer = foo.Trainer(net.all_params(), 'sgd', {'learning_rate': 0.01})
+
+with train_section():
+    output = net(data)
+    loss = foo.loss.softmax_cross_entropy_loss(output, label)
+    loss.backward()
+
+# Do the update. Trainer needs to know the batch size of the data to normalize
+# the gradient by 1/batch_size.
+trainer.step(data.shape[0])
+```
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index aed11a4bebf1..dc56cb145fce 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -10,6 +10,7 @@ These tutorials introduce a few fundamental concepts in deep learning and how to ..
toctree:: :maxdepth: 1 + basic/foo basic/ndarray basic/symbol basic/module diff --git a/example/autograd/actor_critic.py b/example/autograd/actor_critic.py index 1e87178f3679..7a716b23fc4d 100644 --- a/example/autograd/actor_critic.py +++ b/example/autograd/actor_critic.py @@ -30,12 +30,12 @@ class Policy(nn.Layer): def __init__(self, **kwargs): super(Policy, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.dense = nn.Dense(16, in_units=4, activation='relu') self.action_pred = nn.Dense(2, in_units=16) self.value_pred = nn.Dense(1, in_units=16) - def generic_forward(self, F, x): + def forward(self, F, x): x = self.dense(x) probs = self.action_pred(x) values = self.value_pred(x) diff --git a/example/autograd/resnet.py b/example/autograd/resnet.py index c87193338dde..5715eeaf9403 100644 --- a/example/autograd/resnet.py +++ b/example/autograd/resnet.py @@ -14,7 +14,7 @@ def conv3x3(filters, stride, in_filters): class BasicBlockV1(nn.Layer): def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): super(BasicBlockV1, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.conv1 = conv3x3(filters, stride, in_filters) self.bn1 = nn.BatchNorm(num_features=in_filters) self.conv2 = conv3x3(filters, 1, filters) @@ -24,7 +24,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): self.bn_ds = nn.BatchNorm(num_features=filters) self.downsample = downsample - def generic_forward(self, domain, x): + def forward(self, domain, x): residual = x out = self.conv1(x) @@ -47,7 +47,7 @@ def generic_forward(self, domain, x): class BottleneckV1(nn.Layer): def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): super(BottleneckV1, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.conv1 = nn.Conv2D(filters=filters//4, kernel_size=1, strides=1, in_filters=in_filters) self.bn1 = nn.BatchNorm(num_features=filters//4) self.conv2 = conv3x3(filters//4, stride, filters//4) @@ -59,7 +59,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): self.bn_ds = nn.BatchNorm(num_features=filters) self.downsample = downsample - def generic_forward(self, domain, x): + def forward(self, domain, x): residual = x out = self.conv1(x) @@ -86,7 +86,7 @@ def generic_forward(self, domain, x): class ResnetV1(nn.Layer): def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): super(ResnetV1, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): assert len(layers) == len(filters) - 1 self._thumbnail = thumbnail if thumbnail: @@ -115,7 +115,7 @@ def _make_layer(self, block, layers, filters, stride, in_filters=0): layer.add(block(filters, 1, False, in_filters=filters)) return layer - def generic_forward(self, domain, x): + def forward(self, domain, x): x = self.conv0(x) if not self._thumbnail: x = self.bn0(x) @@ -134,7 +134,7 @@ def generic_forward(self, domain, x): class BasicBlockV2(nn.Layer): def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): super(BasicBlockV2, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.bn1 = nn.BatchNorm(num_features=in_filters) self.conv1 = conv3x3(filters, stride, in_filters) self.bn2 = nn.BatchNorm(num_features=filters) @@ -145,7 +145,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): else: self.downsample = None - def generic_forward(self, domain, x): + def forward(self, domain, x): if not self.downsample: residual = x x = 
self.bn1(x) @@ -164,7 +164,7 @@ def generic_forward(self, domain, x): class BottleneckV2(nn.Layer): def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): super(BottleneckV2, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.bn1 = nn.BatchNorm(num_features=in_filters) self.conv1 = conv3x3(filters//4, 1, in_filters) self.bn2 = nn.BatchNorm(num_features=filters//4) @@ -177,7 +177,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs): else: self.downsample = None - def generic_forward(self, domain, x): + def forward(self, domain, x): if not self.downsample: residual = x x = self.bn1(x) @@ -199,7 +199,7 @@ def generic_forward(self, domain, x): class ResnetV2(nn.Layer): def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs): super(ResnetV2, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): assert len(layers) == len(filters) - 1 self._thumbnail = thumbnail self.bn_data = nn.BatchNorm(num_features=3, scale=False, center=False) @@ -230,7 +230,7 @@ def _make_layer(self, block, layers, filters, stride, in_filters=0): layer.add(block(filters, 1, False, in_filters=filters)) return layer - def generic_forward(self, domain, x): + def forward(self, domain, x): x = self.bn_data(x) x = self.conv0(x) if not self._thumbnail: diff --git a/example/autograd/super_resolution.py b/example/autograd/super_resolution.py index 89cc58b8ad6e..3c66d7b09dcd 100644 --- a/example/autograd/super_resolution.py +++ b/example/autograd/super_resolution.py @@ -90,14 +90,14 @@ def _rearrange(raw, F, upscale_factor): class SuperResolutionNet(nn.Layer): def __init__(self, upscale_factor): super(SuperResolutionNet, self).__init__() - with self.scope: + with self.name_scope(): self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), in_filters=1) self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=64) self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=64) self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=32) self.upscale_factor = upscale_factor - def generic_forward(self, F, x): + def forward(self, F, x): x = F.Activation(self.conv1(x), act_type='relu') x = F.Activation(self.conv2(x), act_type='relu') x = F.Activation(self.conv3(x), act_type='relu') diff --git a/example/autograd/word_language_model/model.py b/example/autograd/word_language_model/model.py index 8a2a7d92a054..97622566c0d3 100644 --- a/example/autograd/word_language_model/model.py +++ b/example/autograd/word_language_model/model.py @@ -6,7 +6,7 @@ class RNNModel(nn.Layer): def __init__(self, mode, vocab_size, num_embed, num_hidden, num_layers, dropout=0.5, tie_weights=False, **kwargs): super(RNNModel, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.drop = nn.Dropout(dropout) self.encoder = nn.Embedding(vocab_size, num_embed) self.rnn = rnn.FusedRNNCell(num_hidden, num_layers, mode=mode, @@ -20,7 +20,7 @@ def __init__(self, mode, vocab_size, num_embed, num_hidden, self.num_hidden = num_hidden - def generic_forward(self, F, inputs, hidden): + def forward(self, F, inputs, hidden): emb = self.drop(self.encoder(inputs)) output, hidden = self.rnn.unroll(None, emb, layout='TNC', merge_outputs=True) output = self.drop(output) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 90270f776456..f6105393d2f5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -416,6 +416,12 @@ MXNET_DLL int 
MXNDArrayGetDType(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id); +/*! + * \brief return gradient buffer attached to this NDArray + * \param handle NDArray handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out); /*! * \brief detach and ndarray from computation graph by clearing entry_ * \param handle NDArray handle diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 504fd5e7676e..e349b3091c56 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -107,6 +107,10 @@ class NDArray { SetTBlob(); return tblob_; } + /*! + * \return the gradient ndarray. + */ + NDArray grad() const; /*! * \return the context of NDArray, this function is only valid when the NDArray is not empty */ diff --git a/python/mxnet/foo/loss.py b/python/mxnet/foo/loss.py index 9bfa3795c2e3..8f7193383ec4 100644 --- a/python/mxnet/foo/loss.py +++ b/python/mxnet/foo/loss.py @@ -82,14 +82,14 @@ def custom_loss(loss, output, label, weight=None, sample_weight=None, batch_axis loss : BaseLoss created loss - Example - ------- - The following code defines a least square loss (same as `nn.l2_loss`):: - data = mx.sym.var('data') - output = mx.sym.FullyConnected(data, num_hidden=1) - label = mx.sym.var('label') - loss = mx.sym.square(output - label.reshape((-1, 1)))/2 - loss = nn.custom_loss(loss, output, label, name='l2') + Examples + -------- + >>> # To define a least square loss (same as `l2_loss`) + >>> data = mx.sym.var('data') + >>> output = mx.sym.FullyConnected(data, num_hidden=1) + >>> label = mx.sym.var('label') + >>> loss = mx.sym.square(output - label.reshape((-1, 1)))/2 + >>> loss = nn.custom_loss(loss, output, label, name='l2') """ F = _get_F(loss) loss = _apply_weighting(F, loss, weight, sample_weight) diff --git a/python/mxnet/foo/nn/conv_layers.py b/python/mxnet/foo/nn/conv_layers.py index f70aa11a29db..d26bebe97c57 100644 --- a/python/mxnet/foo/nn/conv_layers.py +++ b/python/mxnet/foo/nn/conv_layers.py @@ -62,7 +62,7 @@ def __init__(self, filters, kernel_size, strides, padding, dilation, kernel_initializer=None, bias_initializer=None, op_name='Convolution', prefix=None, params=None, **kwargs): super(_Conv, self).__init__(prefix=prefix, params=params) - with self.scope: + with self.name_scope(): self._filters = filters self._in_filters = in_filters if isinstance(strides, numeric_types): @@ -93,7 +93,7 @@ def __init__(self, filters, kernel_size, strides, padding, dilation, else: self.act = None - def generic_forward(self, F, x, weight, bias=None): + def forward(self, F, x, weight, bias=None): if bias is None: act = F.invoke(self._op, [x, weight]) else: @@ -520,7 +520,7 @@ def __init__(self, pool_size, strides, padding, global_pool, pool_type, **kwargs 'pool_type': pool_type} self._op = symbol.CachedOp('Pooling', 1, **attrs) - def generic_forward(self, F, x): + def forward(self, F, x): return F.invoke(self._op, [x]) diff --git a/python/mxnet/foo/nn/layer.py b/python/mxnet/foo/nn/layer.py index 5be7f1ff2085..172b3cb0c5cc 100644 --- a/python/mxnet/foo/nn/layer.py +++ b/python/mxnet/foo/nn/layer.py @@ -59,11 +59,11 @@ class Layer(object): class Net(nn.Layer): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.dense0 = nn.Dense(20, in_units=10) self.dense1 = nn.Dense(20, in_units=20) - def forward(self, x): + def forward(self, F, x): x = self.dense0(x) return self.dense1(x) @@ 
-129,8 +129,10 @@ def name(self): return self.prefix[:-1] return self.prefix - @property - def scope(self): + def name_scope(self): + """Returns a name space object managing sublayer and parameter + names. Should be used by `with` statement + """ return self._scope def register_child(self, layer): @@ -147,7 +149,7 @@ def infer_shape(self, *args): inputs = [symbol.var('__input%d__'%i, shape=shape) for i, shape in enumerate(args)] params = {k: v.var() for k, v in self._reg_params.items()} - sym = self.symbol_forward(*inputs, **params) + sym = self.forward(symbol, *inputs, **params) arg_shapes, _, aux_shapes = sym.infer_shape() sdict = {name: shape for name, shape in zip(sym.list_arguments(), arg_shapes)} sdict.update( @@ -157,33 +159,30 @@ def infer_shape(self, *args): def __call__(self, *args): """Call forward.""" - try: - return self.forward(*args) # pylint: disable= no-value-for-parameter - except DeferredInitializationError: - self.infer_shape(*[i.shape for i in args]) - for i in self.params.values(): - i._finish_deferred_init() - return self.forward(*args) # pylint: disable= no-value-for-parameter - - def forward(self, x, *args): + return self.call(*args) # pylint: disable=no-value-for-parameter + + def call(self, x, *args): """Defines the forward computation. Arguments can be either NDArray or Symbol.""" if isinstance(x, NDArray): with x.context as ctx: - params = {k: v.data(ctx) for k, v in self._reg_params.items()} - return self.ndarray_forward(x, *args, **params) + try: + params = {k: v.data(ctx) for k, v in self._reg_params.items()} + except DeferredInitializationError: + arg_shapes = [x.shape] + arg_shapes += [i.shape if isinstance(i, NDArray) else i for i in args] + self.infer_shape(*arg_shapes) + for i in self.params.values(): + i._finish_deferred_init() + params = {k: v.data(ctx) for k, v in self._reg_params.items()} + return self.forward(ndarray, x, *args, **params) else: assert isinstance(x, Symbol), \ - "Layer requires the first argument to forward be either Symbol or NDArray" + "Layer requires the first argument to forward be either " \ + "Symbol or NDArray, but got %s"%type(x) params = {k: v.var() for k, v in self._reg_params.items()} - return self.symbol_forward(x, *args, **params) - - def ndarray_forward(self, x, *args, **kwargs): - return self.generic_forward(ndarray, x, *args, **kwargs) - - def symbol_forward(self, x, *args, **kwargs): - return self.generic_forward(symbol, x, *args, **kwargs) + return self.forward(symbol, x, *args, **params) - def generic_forward(self, F, x, *args, **kwargs): + def forward(self, F, x, *args, **kwargs): """Simple forward supports both `Symbol` and `NDArray` API. 
Parameters @@ -217,13 +216,13 @@ def add(self, layer): """Add layer on top of the stack.""" self.register_child(layer) - def forward(self, x): + def call(self, x): #pylint: disable=arguments-differ for layer in self._children: x = layer(x) return x - def generic_forward(self, F, x, *args, **kwargs): + def forward(self, F, x, *args, **kwargs): raise NotImplementedError @@ -284,7 +283,7 @@ def __init__(self, units, activation=None, use_bias=True, kernel_initializer=None, bias_initializer=None, in_units=0, **kwargs): super(Dense, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self._op = symbol.CachedOp('FullyConnected', 3 if use_bias else 2, num_hidden=units, no_bias=not use_bias) self.weight = self.params.get('weight', shape=(units, in_units), @@ -297,7 +296,7 @@ def __init__(self, units, activation=None, use_bias=True, else: self.act = None - def generic_forward(self, F, x, weight, bias=None): + def forward(self, F, x, weight, bias=None): if bias is None: act = F.invoke(self._op, [x, weight]) else: @@ -331,7 +330,7 @@ def __init__(self, activation, **kwargs): def _alias(self): return self._act_type - def generic_forward(self, F, x): + def forward(self, F, x): return F.invoke(self._op, [x]) @@ -348,14 +347,14 @@ class Dropout(Layer): References ---------- - - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting]( + [Dropout: A Simple Way to Prevent Neural Networks from Overfitting]( http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) """ def __init__(self, rate, **kwargs): super(Dropout, self).__init__(**kwargs) self._op = symbol.CachedOp('Dropout', 1, p=rate) - def generic_forward(self, F, x): + def forward(self, F, x): return F.invoke(self._op, [x]) @@ -407,7 +406,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True, shape=(num_features,), init=running_variance_initializer) - def generic_forward(self, F, x, gamma, beta, running_mean, running_var): + def forward(self, F, x, gamma, beta, running_mean, running_var): return F.invoke(self._op, [x, gamma, beta, running_mean, running_var]) @@ -427,7 +426,7 @@ def __init__(self, alpha, **kwargs): super(LeakyReLU, self).__init__(**kwargs) self._op = symbol.CachedOp('LeakyReLU', 1, act_type='leaky', slope=alpha) - def generic_forward(self, F, x): + def forward(self, F, x): return F.invoke(self._op, [x]) @@ -463,5 +462,5 @@ def __init__(self, input_dim, output_dim, dtype='float32', self.weight = self.params.get('weight', shape=(input_dim, output_dim), init=embeddings_initializer) - def generic_forward(self, F, x, weight): + def forward(self, F, x, weight): return F.invoke(self._op, [x, weight]) diff --git a/python/mxnet/foo/parameter.py b/python/mxnet/foo/parameter.py index 5901f1873d85..50c9c614b853 100644 --- a/python/mxnet/foo/parameter.py +++ b/python/mxnet/foo/parameter.py @@ -85,6 +85,8 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Xavier(), allow_deferring=True): """Intialize parameter and gradient arrays. Only used for `NDArray` API. + Parameters + ---------- init : Initializer The initializer to use. Overrides `Parameter.init` and default_init. ctx : Context or list of Context, defaults to `context.current_context()`. @@ -295,8 +297,8 @@ def get(self, name, **kwargs): found, `get` will create a new Parameter with key-word arguments and insert it to self. - Parameter - --------- + Parameters + ---------- name : str name of the desired Parameter. It will be prepended with this dictionary's prefix. 
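For context, the deferred-initialization path that `call` now handles can be exercised
roughly like this (a minimal sketch with arbitrary sizes; assumes the `foo.nn` API above):

```python
import mxnet as mx
from mxnet.foo import nn

dense = nn.Dense(5)               # in_units omitted, so the weight shape is unknown
dense.all_params().initialize()   # allocation is deferred until shapes are known
out = dense(mx.nd.ones((4, 10)))  # the first call infers shapes and finishes initialization
print(dense.weight.data().shape)  # (5, 10)
```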
diff --git a/python/mxnet/foo/rnn/rnn_cell.py b/python/mxnet/foo/rnn/rnn_cell.py index d0f6ebbcd118..2733cebe46bd 100644 --- a/python/mxnet/foo/rnn/rnn_cell.py +++ b/python/mxnet/foo/rnn/rnn_cell.py @@ -301,7 +301,7 @@ def _get_activation(self, F, inputs, activation, **kwargs): else: return activation(inputs, **kwargs) - def forward(self, inputs, states): + def call(self, inputs, states): """Unroll the recurrent cell for one time step. Parameters @@ -329,7 +329,7 @@ def forward(self, inputs, states): """ # pylint: disable= arguments-differ self._counter += 1 - return super(RecurrentCell, self).forward(inputs, states) + return super(RecurrentCell, self).call(inputs, states) @@ -370,8 +370,8 @@ def _gate_names(self): def _alias(self): return 'rnn' - def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias, - h2h_weight, h2h_bias): + def forward(self, F, inputs, states, i2h_weight, i2h_bias, + h2h_weight, h2h_bias): name = self._curr_prefix i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias, num_hidden=self._num_hidden, @@ -425,8 +425,8 @@ def _gate_names(self): def _alias(self): return 'lstm' - def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias, - h2h_weight, h2h_bias): + def forward(self, F, inputs, states, i2h_weight, i2h_bias, + h2h_weight, h2h_bias): name = self._curr_prefix i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias, num_hidden=self._num_hidden*4, @@ -487,8 +487,8 @@ def _gate_names(self): def _alias(self): return 'gru' - def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias, - h2h_weight, h2h_bias): + def forward(self, F, inputs, states, i2h_weight, i2h_bias, + h2h_weight, h2h_bias): # pylint: disable=too-many-locals name = self._curr_prefix prev_state_h = states[0] @@ -780,7 +780,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N return inputs, next_states - def generic_forward(self, *args, **kwargs): + def forward(self, *args, **kwargs): raise NotImplementedError @@ -804,7 +804,7 @@ def state_info(self, batch_size=0): def _alias(self): return 'dropout' - def generic_forward(self, F, inputs, states): + def forward(self, F, inputs, states): if self.dropout > 0: inputs = F.Dropout(data=inputs, p=self.dropout) return inputs, states @@ -814,7 +814,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs) if isinstance(inputs, tensor_types): - return self.generic_forward(F, inputs, begin_state if begin_state else []) + return self.forward(F, inputs, begin_state if begin_state else []) else: return super(DropoutCell, self).unroll( length, inputs, begin_state=begin_state, layout=layout, @@ -858,7 +858,7 @@ def unpack_weights(self, args): def pack_weights(self, args): return self.base_cell.pack_weights(args) - def generic_forward(self, F, inputs, states): + def forward(self, F, inputs, states): raise NotImplementedError @@ -886,7 +886,7 @@ def reset(self): super(ZoneoutCell, self).reset() self.prev_output = None - def generic_forward(self, F, inputs, states): + def forward(self, F, inputs, states): cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states next_output, next_states = cell(inputs, states) mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p)) @@ -915,7 +915,7 @@ class ResidualCell(ModifierCell): def __init__(self, base_cell): super(ResidualCell, self).__init__(base_cell) - def generic_forward(self, F, inputs, states): + def 
forward(self, F, inputs, states): output, states = self.base_cell(inputs, states) output = F.elemwise_add(output, inputs, name="%s_plus_residual" % output.name) return output, states diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 8900843f5937..a70b81b8f99b 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -62,6 +62,12 @@ 3 : np.uint8, 4 : np.int32 } + +_GRAD_REQ_MAP = { + 'null': 0, + 'write': 1, + 'add': 3 +} # pylint: enable= no-member def _new_empty_handle(): @@ -116,8 +122,9 @@ class NDArray(NDArrayBase): def __repr__(self): """Returns a string representation of the array.""" shape_info = 'x'.join(['%d' % x for x in self.shape]) - return '<%s %s @%s>' % (self.__class__.__name__, - shape_info, self.context) + return '%s\n<%s %s @%s>' % (str(self.asnumpy()), + self.__class__.__name__, + shape_info, self.context) def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ @@ -926,6 +933,34 @@ def as_in_context(self, context): return self return self.copyto(context) + def set_grad(self, grad_req='write'): + """Attach a gradient buffer to this NDArray, so that `backward` + can compute gradient with respect to it. + + Parameters + ---------- + grad_req : {'write', 'add', 'null'} + How gradient will be accumulated. + - 'write': gradient will be overwritten on every backward. + - 'add': gradient will be added to existing value on every backward. + - 'null': do not compute gradient for this NDArray. + """ + grad = zeros_like(self) # pylint: disable=undefined-variable + grad_req = _GRAD_REQ_MAP[grad_req] + check_call(_LIB.MXAutogradMarkVariables( + 1, ctypes.pointer(self.handle), + ctypes.pointer(mx_uint(grad_req)), + ctypes.pointer(grad.handle))) + + @property + def grad(self): + """Returns gradient buffer attached to this NDArray.""" + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetGrad(self.handle, ctypes.byref(hdl))) + if hdl.value is None: + return None + return NDArray(hdl) + def detach(self): """Returns a new NDArray, detached from the current graph.""" hdl = NDArrayHandle() diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 35c36841191b..84199b952a54 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -17,7 +17,7 @@ from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle from .base import check_call, MXNetError, NotImplementedForSymbol, _Null # pylint: disable=unused-import from .context import Context -from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP from .name import NameManager # pylint: disable=unused-import from .executor import Executor from . import _symbol_internal as _internal @@ -42,7 +42,6 @@ from ._ctypes.symbol import SymbolBase, _set_symbol_class from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import -_GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3} class Symbol(SymbolBase): """Symbol is symbolic graph of the mxnet.""" @@ -1572,7 +1571,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', executor.aux_arrays = aux_states return executor - def grad(self, wrt): + def gradient(self, wrt): """Gets the autodiff of current symbol. This function can only be used if current symbol is a loss function. 
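A rough sketch of how the new `set_grad`/`grad` NDArray API is meant to be used with
autograd (it mirrors the unit test added below; assumes `train_section` is importable
from `mxnet.autograd`):

```python
import mxnet as mx
from mxnet.autograd import train_section

x = mx.nd.ones((3,))
x.set_grad()             # attach a zero-initialized gradient buffer (grad_req='write')
with train_section():    # record the forward computation
    y = x * 2
y.backward()             # same as y.backward(mx.nd.ones_like(y))
print(x.grad)            # gradient of sum(y) w.r.t. x, i.e. all 2s
```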
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9d60c8615027..cfd26301388a 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -398,6 +398,19 @@ int MXNDArrayGetContext(NDArrayHandle handle, API_END(); } + +int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + NDArray ret = arr->grad(); + if (ret.is_none()) { + *out = NULL; + } else { + *out = new NDArray(ret); + } + API_END(); +} + int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) { API_BEGIN(); NDArray *arr = static_cast(handle); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6f1795d6f368..279ae9617fed 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -24,6 +24,14 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg); namespace mxnet { +NDArray NDArray::grad() const { + if (this->entry_.ag_node && this->entry_.ag_node->out_grads.size()) { + CHECK_EQ(this->entry_.ag_node->out_grads.size(), 1); + return this->entry_.ag_node->out_grads[0]; + } + return NDArray(); +} + NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; if (AutogradRuntime::Get()->IsTraining()) { diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 9b2ea4b867f3..eb73a125e819 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -234,6 +234,17 @@ def test_retain_grad(): "differentiating the same graph twice without retain_graph should fail") +def test_set_grad(): + x = mx.nd.zeros((10,)) + assert x.grad is None + x.set_grad() + with train_section(): + y = x * 2 + assert y.grad is None + y.backward() + assert (x.grad.asnumpy() == 2).all() + + if __name__ == "__main__": import nose nose.runmodule() diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py index 42917855df34..bd9eca662fa4 100644 --- a/tests/python/unittest/test_nn.py +++ b/tests/python/unittest/test_nn.py @@ -27,11 +27,11 @@ def test_parameter_sharing(): class Net(nn.Layer): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.scope: + with self.name_scope(): self.dense0 = nn.Dense(5, in_units=5) self.dense1 = nn.Dense(5, in_units=5) - def generic_forward(self, F, x): + def forward(self, F, x): return self.dense1(self.dense0(x)) net1 = Net(prefix='net1_')