diff --git a/docs/api/python/foo.loss.md b/docs/api/python/foo.loss.md
deleted file mode 100644
index b35a6942c6a1..000000000000
--- a/docs/api/python/foo.loss.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# Foo Loss API
-
-```eval_rst
-.. currentmodule:: mxnet.foo.loss
-```
-
-```eval_rst
-.. warning:: This package is currently experimental and may change in the near future.
-```
-
-## API Reference
-
-
-
-```eval_rst
-.. automethod:: mxnet.foo.loss.custom_loss
-.. automethod:: mxnet.foo.loss.multitask_loss
-.. automethod:: mxnet.foo.loss.l1_loss
-.. automethod:: mxnet.foo.loss.l2_loss
-.. automethod:: mxnet.foo.loss.softmax_cross_entropy_loss
-```
-
-
diff --git a/docs/api/python/foo.md b/docs/api/python/foo.md
index 72897fc1676b..0bff54ca9c36 100644
--- a/docs/api/python/foo.md
+++ b/docs/api/python/foo.md
@@ -1,4 +1,4 @@
-# Foo API
+# Foo Package
```eval_rst
.. currentmodule:: mxnet.foo
@@ -8,7 +8,18 @@
.. warning:: This package is currently experimental and may change in the near future.
```
-## API Reference
+## Overview
+
+The Foo package is a high-level interface for MXNet designed to be easy to use while
+keeping most of the flexibility of the low-level API. Foo supports both imperative
+and symbolic programming, making it easy to train complex models imperatively
+in Python and then deploy them as a symbolic graph in C++ and Scala.
+
+## Parameter
+
+```eval_rst
+.. currentmodule:: mxnet.foo
+```
@@ -17,8 +28,192 @@
:members:
.. autoclass:: mxnet.foo.ParameterDict
:members:
+```
+
+
+
+
+## Neural Network Layers
+
+```eval_rst
+.. currentmodule:: mxnet.foo.nn
+```
+
+### Containers
+
+
+
+```eval_rst
+.. currentmodule:: mxnet.foo.nn
+.. autoclass:: mxnet.foo.nn.Layer
+ :members:
+
+ .. automethod:: __call__
+.. autoclass:: mxnet.foo.nn.Sequential
+ :members:
+```
+
+
+
+### Basic Layers
+
+
+
+```eval_rst
+.. currentmodule:: mxnet.foo.nn
+.. autoclass:: mxnet.foo.nn.Dense
+ :members:
+.. autoclass:: mxnet.foo.nn.Activation
+ :members:
+.. autoclass:: mxnet.foo.nn.Dropout
+ :members:
+.. autoclass:: mxnet.foo.nn.BatchNorm
+ :members:
+.. autoclass:: mxnet.foo.nn.LeakyReLU
+ :members:
+.. autoclass:: mxnet.foo.nn.Embedding
+ :members:
+```
+
+
+
+### Convolutional Layers
+
+
+
+```eval_rst
+.. currentmodule:: mxnet.foo.nn
+.. autoclass:: mxnet.foo.nn.Conv1D
+ :members:
+.. autoclass:: mxnet.foo.nn.Conv2D
+ :members:
+.. autoclass:: mxnet.foo.nn.Conv3D
+ :members:
+.. autoclass:: mxnet.foo.nn.Conv1DTranspose
+ :members:
+.. autoclass:: mxnet.foo.nn.Conv2DTranspose
+ :members:
+.. autoclass:: mxnet.foo.nn.Conv3DTranspose
+ :members:
+```
+
+
+
+
+### Pooling Layers
+
+
+
+```eval_rst
+.. currentmodule:: mxnet.foo.nn
+.. autoclass:: mxnet.foo.nn.MaxPool1D
+ :members:
+.. autoclass:: mxnet.foo.nn.MaxPool2D
+ :members:
+.. autoclass:: mxnet.foo.nn.MaxPool3D
+ :members:
+.. autoclass:: mxnet.foo.nn.AvgPool1D
+ :members:
+.. autoclass:: mxnet.foo.nn.AvgPool2D
+ :members:
+.. autoclass:: mxnet.foo.nn.AvgPool3D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalMaxPool1D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalMaxPool2D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalMaxPool3D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalAvgPool1D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalAvgPool2D
+ :members:
+.. autoclass:: mxnet.foo.nn.GlobalAvgPool3D
+ :members:
+```
+
+
+
+
+## Recurrent Layers
+
+```eval_rst
+.. currentmodule:: mxnet.foo.rnn
+```
+
+
+
+```eval_rst
+.. autoclass:: mxnet.foo.rnn.RecurrentCell
+ :members:
+
+ .. automethod:: __call__
+.. autoclass:: mxnet.foo.rnn.LSTMCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.GRUCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.RNNCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.FusedRNNCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.SequentialRNNCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.BidirectionalCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.DropoutCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.ZoneoutCell
+ :members:
+.. autoclass:: mxnet.foo.rnn.ResidualCell
+ :members:
+```
+
+
+
+## Trainer
+
+```eval_rst
+.. currentmodule:: mxnet.foo
+```
+
+
+
+```eval_rst
.. autoclass:: mxnet.foo.Trainer
:members:
```
+
+## Loss Functions
+
+```eval_rst
+.. currentmodule:: mxnet.foo.loss
+```
+
+
+
+```eval_rst
+.. automethod:: mxnet.foo.loss.custom_loss
+.. automethod:: mxnet.foo.loss.multitask_loss
+.. automethod:: mxnet.foo.loss.l1_loss
+.. automethod:: mxnet.foo.loss.l2_loss
+.. automethod:: mxnet.foo.loss.softmax_cross_entropy_loss
+```
+
+
+
+## Utilities
+
+```eval_rst
+.. currentmodule:: mxnet.foo.utils
+```
+
+
+
+```eval_rst
+.. automethod:: mxnet.foo.utils.split_data
+.. automethod:: mxnet.foo.utils.load_data
+```
+
+
diff --git a/docs/api/python/foo.nn.md b/docs/api/python/foo.nn.md
deleted file mode 100644
index 184f0ecc5dbb..000000000000
--- a/docs/api/python/foo.nn.md
+++ /dev/null
@@ -1,72 +0,0 @@
-# Foo NN API
-
-```eval_rst
-.. currentmodule:: mxnet.foo.nn
-```
-
-```eval_rst
-.. warning:: This package is currently experimental and may change in the near future.
-```
-
-## API Reference
-
-
-
-```eval_rst
-.. currentmodule:: mxnet.foo.nn
-.. autoclass:: mxnet.foo.nn.Layer
- :members:
-
- .. automethod:: __call__
-.. autoclass:: mxnet.foo.nn.Sequential
- :members:
-.. autoclass:: mxnet.foo.nn.Dense
- :members:
-.. autoclass:: mxnet.foo.nn.Activation
- :members:
-.. autoclass:: mxnet.foo.nn.Dropout
- :members:
-.. autoclass:: mxnet.foo.nn.BatchNorm
- :members:
-.. autoclass:: mxnet.foo.nn.LeakyReLU
- :members:
-
-.. autoclass:: mxnet.foo.nn.Conv1D
- :members:
-.. autoclass:: mxnet.foo.nn.Conv2D
- :members:
-.. autoclass:: mxnet.foo.nn.Conv3D
- :members:
-.. autoclass:: mxnet.foo.nn.Conv1DTranspose
- :members:
-.. autoclass:: mxnet.foo.nn.Conv2DTranspose
- :members:
-.. autoclass:: mxnet.foo.nn.Conv3DTranspose
- :members:
-.. autoclass:: mxnet.foo.nn.MaxPool1D
- :members:
-.. autoclass:: mxnet.foo.nn.MaxPool2D
- :members:
-.. autoclass:: mxnet.foo.nn.MaxPool3D
- :members:
-.. autoclass:: mxnet.foo.nn.AvgPool1D
- :members:
-.. autoclass:: mxnet.foo.nn.AvgPool2D
- :members:
-.. autoclass:: mxnet.foo.nn.AvgPool3D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalMaxPool1D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalMaxPool2D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalMaxPool3D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalAvgPool1D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalAvgPool2D
- :members:
-.. autoclass:: mxnet.foo.nn.GlobalAvgPool3D
- :members:
-```
-
-
diff --git a/docs/api/python/foo.rnn.md b/docs/api/python/foo.rnn.md
deleted file mode 100644
index e2c2b37a1f6f..000000000000
--- a/docs/api/python/foo.rnn.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# Foo RNN API
-
-```eval_rst
-.. currentmodule:: mxnet.foo.rnn
-```
-
-```eval_rst
-.. warning:: This package is currently experimental and may change in the near future.
-```
-
-## API Reference
-
-
-
-```eval_rst
-.. autoclass:: mxnet.foo.rnn.RecurrentCell
- :members:
-
- .. automethod:: __call__
-.. autoclass:: mxnet.foo.rnn.LSTMCell
- :members:
-.. autoclass:: mxnet.foo.rnn.GRUCell
- :members:
-.. autoclass:: mxnet.foo.rnn.RNNCell
- :members:
-.. autoclass:: mxnet.foo.rnn.FusedRNNCell
- :members:
-.. autoclass:: mxnet.foo.rnn.SequentialRNNCell
- :members:
-.. autoclass:: mxnet.foo.rnn.BidirectionalCell
- :members:
-.. autoclass:: mxnet.foo.rnn.DropoutCell
- :members:
-.. autoclass:: mxnet.foo.rnn.ZoneoutCell
- :members:
-.. autoclass:: mxnet.foo.rnn.ResidualCell
- :members:
-```
-
-
diff --git a/docs/api/python/foo.utils.md b/docs/api/python/foo.utils.md
deleted file mode 100644
index 21dea1a0c2b9..000000000000
--- a/docs/api/python/foo.utils.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Foo Utility API
-
-```eval_rst
-.. currentmodule:: mxnet.foo.utils
-```
-
-```eval_rst
-.. warning:: This package is currently experimental and may change in the near future.
-```
-
-## API Reference
-
-
-
-```eval_rst
-.. automethod:: mxnet.foo.utils.split_data
-.. automethod:: mxnet.foo.utils.load_data
-```
-
-
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index fe102eb6a601..43f02677126a 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -29,10 +29,6 @@ imported by running:
symbol
module
foo
- foo.nn
- foo.rnn
- foo.loss
- foo.utils
rnn
kvstore
io
diff --git a/docs/tutorials/basic/foo.md b/docs/tutorials/basic/foo.md
new file mode 100644
index 000000000000..84b14278158c
--- /dev/null
+++ b/docs/tutorials/basic/foo.md
@@ -0,0 +1,291 @@
+# Foo - High-level Interface
+
+The Foo package is a high-level interface for MXNet designed to be easy to use while
+keeping most of the flexibility of the low-level API. Foo supports both imperative
+and symbolic programming, making it easy to train complex models imperatively
+in Python and then deploy them as a symbolic graph in C++ and Scala.
+
+This tutorial covers four topics:
+- MXNet NDArray as a replacement for numpy for asynchronous scientific computing
+across CPU and GPU.
+- Automatic differentiation with NDArray.
+- Defining and training neural network models with Foo's imperative API.
+- [TODO] Saving trained models as a symbolic graph for easy production deployment.
+
+## Setup
+First, let's import numpy and MXNet:
+
+```python
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+```
+
+## NDArray
+
+### Creating NDArray
+
+NDArray is similar to numpy's ndarray, but supports asynchronous operations
+and GPU computation. There are many ways to create an NDArray.
+
+Construct from a (nested) list:
+```python
+x = mx.nd.array([[1, 2, 3], [4, 5, 6]])
+print(x)
+```
+
+Construct from a numpy array:
+```python
+x_numpy = np.ones((2, 3))
+x = mx.nd.array(x_numpy)
+print(x)
+```
+
+Array construction routines:
+```python
+# create a 2x3 array of ones
+x = mx.nd.ones((2, 3))
+print(x)
+# create a 2x3 array of zeros
+x = mx.nd.zeros((2, 3))
+print(x)
+# create a 1-d array of 0 to 5 and reshape it to 2x3
+x = mx.nd.arange(6).reshape((2, 3))
+print(x)
+```
+
+You can convert any NDArray to a numpy array with `.asnumpy()`:
+```python
+z = x.asnumpy()
+print(z)
+```
+
+### NDArray Operations
+
+NDArray supports a wide range of operations. Simple operations can be called
+with Python operator syntax:
+
+```python
+x = mx.nd.array([[1, 2], [3, 4]])
+y = mx.nd.array([[4, 3], [2, 1]])
+print(x + y)
+```
+
+You can also call operators from the `mxnet.ndarray` (or `mx.nd` for short) namespace:
+
+```python
+z = mx.nd.add(x, y)
+print(z)
+```
+
+You can also pass additional keyword arguments to operators:
+
+```python
+z = mx.nd.sum(x, axis=0)
+print('axis=0:', z)
+z = mx.nd.sum(x, axis=1)
+print('axis=1:', z)
+```
+
+By default, operators create a new NDArray for the return value. You can specify `out`
+to write the result into a pre-allocated buffer instead:
+
+```python
+z = mx.nd.empty((2, 2))
+mx.nd.add(x, y, out=z)
+print(z)
+```
+
+### Using GPU
+
+Each NDArray lives on a `Context`. MXNet supports `mx.cpu()` for CPU and `mx.gpu(0)`,
+`mx.gpu(1)`, etc., for GPUs. You can specify the context when creating an NDArray:
+
+```python
+# creates on CPU (the default).
+# Replace mx.cpu() with mx.gpu(0) if you have a GPU.
+x = mx.nd.zeros((2, 2), ctx=mx.cpu())
+print(x)
+x = mx.nd.array([[1, 2], [3, 4]], ctx=mx.cpu())
+print(x)
+```
+
+You can copy arrays between devices with `.copyto()`:
+
+```python
+# Copy x to cpu. Replace with mx.gpu(0) if you have a GPU.
+y = x.copyto(mx.cpu())
+# Copy x to another NDArray, possibly on another Context.
+x.copyto(y)
+print(y)
+```
+
+See the [NDArray tutorial](ndarray.md) for a more detailed introduction to the
+NDArray API.
+
+## Automatic Differentiation
+
+MXNet supports automatic differentiation with the `autograd` package.
+`autograd` allows you to differentiate a network of NDArray operations.
+This is called define-by-run, i.e., the network is defined on-the-fly by
+running the forward computation. You can define exotic network structures
+and differentiate them, and each iteration can have a totally different
+network structure.
+
+```python
+from mxnet import autograd
+from mxnet.autograd import train_section
+```
+
+To use `autograd`, we must first mark the variables that require gradients and
+attach gradient buffers to them:
+
+```python
+x = mx.nd.array([[1, 2], [3, 4]])
+x.set_grad()  # attaches a zero-initialized gradient buffer (grad_req='write')
+```
+
+Now we can define the network while running the forward computation by wrapping
+it inside a `train_section` scope (operations outside of `train_section` do not
+define a graph and cannot be differentiated):
+
+```python
+with train_section():
+ y = x * 2
+ z = y * x
+```
+
+Let's backprop with `z.backward()`, which is equivalent to
+`z.backward(mx.nd.ones_like(z))`. When `z` has more than one entry, `z.backward()`
+is equivalent to `mx.nd.sum(z).backward()`:
+
+```python
+z.backward()
+print(x.grad)
+```
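+
+Since `y = x * 2` and `z = y * x`, we have `z = 2 * x**2`, so `x.grad` should
+equal `4 * x`. As a quick sanity check, here is a minimal sketch showing that an
+explicit head gradient of ones gives the same result as the plain `z.backward()`
+call above (the default `'write'` mode simply overwrites the gradient buffer on
+each backward pass):
+
+```python
+with train_section():
+    y = x * 2
+    z = y * x
+z.backward(mx.nd.ones_like(z))
+print(x.grad)  # expect 4 * x, same as before
+```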
+
+## Neural Network and Layers
+
+Neural networks (and other machine learning models) can be defined and trained
+with the `foo.nn` and `foo.rnn` packages. A typical training script has the following
+steps:
+
+- Define the network
+- Initialize the parameters
+- Loop over the inputs; for each mini-batch:
+  - Forward the input through the network to get the output
+  - Compute the loss from the output and the label
+  - Backprop the gradient
+  - Update the parameters with gradient descent
+
+
+### Define Network
+
+`foo.nn.Layer` is the basic building block of models. You can define networks by
+composing and subclassing `Layer`:
+
+```python
+import mxnet.foo as foo
+from mxnet.foo import nn
+
+class Net(nn.Layer):
+ def __init__(self, **kwargs):
+ super(Net, self).__init__(**kwargs)
+        with self.name_scope():
+            # layers created inside name_scope will inherit the name space
+            # from the parent layer.
+ self.conv1 = nn.Conv2D(6, kernel_size=5)
+            self.pool1 = nn.MaxPool2D(pool_size=2)
+            self.conv2 = nn.Conv2D(16, kernel_size=5)
+            self.pool2 = nn.MaxPool2D(pool_size=2)
+ self.fc1 = nn.Dense(120)
+ self.fc2 = nn.Dense(84)
+ self.fc3 = nn.Dense(10)
+
+ def forward(self, F, x):
+ x = self.pool1(F.relu(self.conv1(x)))
+ x = self.pool2(F.relu(self.conv2(x)))
+        # 0 means copy over the size from the corresponding dimension.
+        # -1 means infer the size from the rest of the dimensions.
+ x = x.reshape((0, -1))
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+```
+
+### Initialize Parameters
+
+A network must be created and initialized before it can be used:
+
+```python
+net = Net()
+# Initialize on CPU. Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`,
+# etc to use one or more GPUs.
+net.all_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
+```
+
+Note that because we didn't specify the input sizes of the layers in Net's constructor,
+the shapes of their parameters cannot be determined at this point. The actual
+initialization is deferred to the first forward pass, i.e., accessing
+`net.fc1.weight.data()` now would raise an exception.
+
+You can trigger the deferred initialization by running a forward pass:
+
+```python
+data = mx.nd.random_normal(shape=(10, 1, 32, 32)) # dummy data
+output = net(data)
+```
+
+Alternatively, you can specify the input size when creating a layer, e.g.
+`nn.Dense(84, in_units=120)` instead of `nn.Dense(84)`, as shown in the sketch below.
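+
+The following minimal sketch (separate from the training script above) shows that a
+layer created with an explicit `in_units` has fully known parameter shapes, so it can
+be initialized and inspected without running a forward pass first:
+
+```python
+dense = nn.Dense(84, in_units=120)
+dense.all_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
+print(dense.weight.data().shape)  # (84, 120), available immediately
+```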
+
+### Loss Functions
+
+Loss functions take (output, label) pairs and compute a scalar loss for each sample
+in the mini-batch. The scalars measure how far each output is from the label.
+
+There are many predefined loss functions in `foo.loss`. Here we use
+`softmax_cross_entropy_loss` for digit classification.
+
+To compute loss and backprop for one iteration, we do:
+
+```python
+label = mx.nd.arange(10) # dummy label
+with train_section():
+ output = net(data)
+ loss = foo.loss.softmax_cross_entropy_loss(output, label)
+ loss.backward()
+print('loss:', loss)
+print('grad:', net.fc1.weight.grad())
+```
+
+### Updating the Weights
+
+Now that the gradients are computed, we just need to update the weights. This is usually
+done with an update rule like `weight = weight - learning_rate * grad / batch_size`.
+Note that we divide the gradient by `batch_size` because the gradient is aggregated over the
+entire batch. For example,
+
+```python
+lr = 0.01
+for p in net.all_params().values():
+ p.data()[:] -= lr / data.shape[0] * p.grad()
+```
+
+But sometimes you want fancier update rules such as momentum or Adam, and since
+this is commonly needed, foo provides a `Trainer` class for it:
+
+```python
+trainer = foo.Trainer(net.all_params(), 'sgd', {'learning_rate': 0.01})
+
+with train_section():
+ output = net(data)
+ loss = foo.loss.softmax_cross_entropy_loss(output, label)
+ loss.backward()
+
+# do the update. Trainer needs to know the batch size of data to normalize
+# the gradient by 1/batch_size.
+trainer.step(data.shape[0])
+```
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index aed11a4bebf1..dc56cb145fce 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -10,6 +10,7 @@ These tutorials introduce a few fundamental concepts in deep learning and how to
.. toctree::
:maxdepth: 1
+ basic/foo
basic/ndarray
basic/symbol
basic/module
diff --git a/example/autograd/actor_critic.py b/example/autograd/actor_critic.py
index 1e87178f3679..7a716b23fc4d 100644
--- a/example/autograd/actor_critic.py
+++ b/example/autograd/actor_critic.py
@@ -30,12 +30,12 @@
class Policy(nn.Layer):
def __init__(self, **kwargs):
super(Policy, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.dense = nn.Dense(16, in_units=4, activation='relu')
self.action_pred = nn.Dense(2, in_units=16)
self.value_pred = nn.Dense(1, in_units=16)
- def generic_forward(self, F, x):
+ def forward(self, F, x):
x = self.dense(x)
probs = self.action_pred(x)
values = self.value_pred(x)
diff --git a/example/autograd/resnet.py b/example/autograd/resnet.py
index c87193338dde..5715eeaf9403 100644
--- a/example/autograd/resnet.py
+++ b/example/autograd/resnet.py
@@ -14,7 +14,7 @@ def conv3x3(filters, stride, in_filters):
class BasicBlockV1(nn.Layer):
def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
super(BasicBlockV1, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.conv1 = conv3x3(filters, stride, in_filters)
self.bn1 = nn.BatchNorm(num_features=in_filters)
self.conv2 = conv3x3(filters, 1, filters)
@@ -24,7 +24,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
self.bn_ds = nn.BatchNorm(num_features=filters)
self.downsample = downsample
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
residual = x
out = self.conv1(x)
@@ -47,7 +47,7 @@ def generic_forward(self, domain, x):
class BottleneckV1(nn.Layer):
def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
super(BottleneckV1, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.conv1 = nn.Conv2D(filters=filters//4, kernel_size=1, strides=1, in_filters=in_filters)
self.bn1 = nn.BatchNorm(num_features=filters//4)
self.conv2 = conv3x3(filters//4, stride, filters//4)
@@ -59,7 +59,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
self.bn_ds = nn.BatchNorm(num_features=filters)
self.downsample = downsample
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
residual = x
out = self.conv1(x)
@@ -86,7 +86,7 @@ def generic_forward(self, domain, x):
class ResnetV1(nn.Layer):
def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
super(ResnetV1, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
assert len(layers) == len(filters) - 1
self._thumbnail = thumbnail
if thumbnail:
@@ -115,7 +115,7 @@ def _make_layer(self, block, layers, filters, stride, in_filters=0):
layer.add(block(filters, 1, False, in_filters=filters))
return layer
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
x = self.conv0(x)
if not self._thumbnail:
x = self.bn0(x)
@@ -134,7 +134,7 @@ def generic_forward(self, domain, x):
class BasicBlockV2(nn.Layer):
def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
super(BasicBlockV2, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.bn1 = nn.BatchNorm(num_features=in_filters)
self.conv1 = conv3x3(filters, stride, in_filters)
self.bn2 = nn.BatchNorm(num_features=filters)
@@ -145,7 +145,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
else:
self.downsample = None
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
if not self.downsample:
residual = x
x = self.bn1(x)
@@ -164,7 +164,7 @@ def generic_forward(self, domain, x):
class BottleneckV2(nn.Layer):
def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
super(BottleneckV2, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.bn1 = nn.BatchNorm(num_features=in_filters)
self.conv1 = conv3x3(filters//4, 1, in_filters)
self.bn2 = nn.BatchNorm(num_features=filters//4)
@@ -177,7 +177,7 @@ def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
else:
self.downsample = None
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
if not self.downsample:
residual = x
x = self.bn1(x)
@@ -199,7 +199,7 @@ def generic_forward(self, domain, x):
class ResnetV2(nn.Layer):
def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
super(ResnetV2, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
assert len(layers) == len(filters) - 1
self._thumbnail = thumbnail
self.bn_data = nn.BatchNorm(num_features=3, scale=False, center=False)
@@ -230,7 +230,7 @@ def _make_layer(self, block, layers, filters, stride, in_filters=0):
layer.add(block(filters, 1, False, in_filters=filters))
return layer
- def generic_forward(self, domain, x):
+ def forward(self, domain, x):
x = self.bn_data(x)
x = self.conv0(x)
if not self._thumbnail:
diff --git a/example/autograd/super_resolution.py b/example/autograd/super_resolution.py
index 89cc58b8ad6e..3c66d7b09dcd 100644
--- a/example/autograd/super_resolution.py
+++ b/example/autograd/super_resolution.py
@@ -90,14 +90,14 @@ def _rearrange(raw, F, upscale_factor):
class SuperResolutionNet(nn.Layer):
def __init__(self, upscale_factor):
super(SuperResolutionNet, self).__init__()
- with self.scope:
+ with self.name_scope():
self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), in_filters=1)
self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=64)
self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=64)
self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1), in_filters=32)
self.upscale_factor = upscale_factor
- def generic_forward(self, F, x):
+ def forward(self, F, x):
x = F.Activation(self.conv1(x), act_type='relu')
x = F.Activation(self.conv2(x), act_type='relu')
x = F.Activation(self.conv3(x), act_type='relu')
diff --git a/example/autograd/word_language_model/model.py b/example/autograd/word_language_model/model.py
index 8a2a7d92a054..97622566c0d3 100644
--- a/example/autograd/word_language_model/model.py
+++ b/example/autograd/word_language_model/model.py
@@ -6,7 +6,7 @@ class RNNModel(nn.Layer):
def __init__(self, mode, vocab_size, num_embed, num_hidden,
num_layers, dropout=0.5, tie_weights=False, **kwargs):
super(RNNModel, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(vocab_size, num_embed)
self.rnn = rnn.FusedRNNCell(num_hidden, num_layers, mode=mode,
@@ -20,7 +20,7 @@ def __init__(self, mode, vocab_size, num_embed, num_hidden,
self.num_hidden = num_hidden
- def generic_forward(self, F, inputs, hidden):
+ def forward(self, F, inputs, hidden):
emb = self.drop(self.encoder(inputs))
output, hidden = self.rnn.unroll(None, emb, layout='TNC', merge_outputs=True)
output = self.drop(output)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 90270f776456..f6105393d2f5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -416,6 +416,12 @@ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_type,
int *out_dev_id);
+/*!
+ * \brief return gradient buffer attached to this NDArray
+ * \param handle NDArray handle
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out);
/*!
* \brief detach and ndarray from computation graph by clearing entry_
* \param handle NDArray handle
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 504fd5e7676e..e349b3091c56 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -107,6 +107,10 @@ class NDArray {
SetTBlob();
return tblob_;
}
+ /*!
+ * \return the gradient ndarray.
+ */
+ NDArray grad() const;
/*!
* \return the context of NDArray, this function is only valid when the NDArray is not empty
*/
diff --git a/python/mxnet/foo/loss.py b/python/mxnet/foo/loss.py
index 9bfa3795c2e3..8f7193383ec4 100644
--- a/python/mxnet/foo/loss.py
+++ b/python/mxnet/foo/loss.py
@@ -82,14 +82,14 @@ def custom_loss(loss, output, label, weight=None, sample_weight=None, batch_axis
loss : BaseLoss
created loss
- Example
- -------
- The following code defines a least square loss (same as `nn.l2_loss`)::
- data = mx.sym.var('data')
- output = mx.sym.FullyConnected(data, num_hidden=1)
- label = mx.sym.var('label')
- loss = mx.sym.square(output - label.reshape((-1, 1)))/2
- loss = nn.custom_loss(loss, output, label, name='l2')
+ Examples
+ --------
+ >>> # To define a least square loss (same as `l2_loss`)
+ >>> data = mx.sym.var('data')
+ >>> output = mx.sym.FullyConnected(data, num_hidden=1)
+ >>> label = mx.sym.var('label')
+ >>> loss = mx.sym.square(output - label.reshape((-1, 1)))/2
+ >>> loss = nn.custom_loss(loss, output, label, name='l2')
"""
F = _get_F(loss)
loss = _apply_weighting(F, loss, weight, sample_weight)
diff --git a/python/mxnet/foo/nn/conv_layers.py b/python/mxnet/foo/nn/conv_layers.py
index f70aa11a29db..d26bebe97c57 100644
--- a/python/mxnet/foo/nn/conv_layers.py
+++ b/python/mxnet/foo/nn/conv_layers.py
@@ -62,7 +62,7 @@ def __init__(self, filters, kernel_size, strides, padding, dilation,
kernel_initializer=None, bias_initializer=None,
op_name='Convolution', prefix=None, params=None, **kwargs):
super(_Conv, self).__init__(prefix=prefix, params=params)
- with self.scope:
+ with self.name_scope():
self._filters = filters
self._in_filters = in_filters
if isinstance(strides, numeric_types):
@@ -93,7 +93,7 @@ def __init__(self, filters, kernel_size, strides, padding, dilation,
else:
self.act = None
- def generic_forward(self, F, x, weight, bias=None):
+ def forward(self, F, x, weight, bias=None):
if bias is None:
act = F.invoke(self._op, [x, weight])
else:
@@ -520,7 +520,7 @@ def __init__(self, pool_size, strides, padding, global_pool, pool_type, **kwargs
'pool_type': pool_type}
self._op = symbol.CachedOp('Pooling', 1, **attrs)
- def generic_forward(self, F, x):
+ def forward(self, F, x):
return F.invoke(self._op, [x])
diff --git a/python/mxnet/foo/nn/layer.py b/python/mxnet/foo/nn/layer.py
index 5be7f1ff2085..172b3cb0c5cc 100644
--- a/python/mxnet/foo/nn/layer.py
+++ b/python/mxnet/foo/nn/layer.py
@@ -59,11 +59,11 @@ class Layer(object):
class Net(nn.Layer):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.dense0 = nn.Dense(20, in_units=10)
self.dense1 = nn.Dense(20, in_units=20)
- def forward(self, x):
+ def forward(self, F, x):
x = self.dense0(x)
return self.dense1(x)
@@ -129,8 +129,10 @@ def name(self):
return self.prefix[:-1]
return self.prefix
- @property
- def scope(self):
+ def name_scope(self):
+ """Returns a name space object managing sublayer and parameter
+        names. Should be used within a `with` statement.
+ """
return self._scope
def register_child(self, layer):
@@ -147,7 +149,7 @@ def infer_shape(self, *args):
inputs = [symbol.var('__input%d__'%i, shape=shape)
for i, shape in enumerate(args)]
params = {k: v.var() for k, v in self._reg_params.items()}
- sym = self.symbol_forward(*inputs, **params)
+ sym = self.forward(symbol, *inputs, **params)
arg_shapes, _, aux_shapes = sym.infer_shape()
sdict = {name: shape for name, shape in zip(sym.list_arguments(), arg_shapes)}
sdict.update(
@@ -157,33 +159,30 @@ def infer_shape(self, *args):
def __call__(self, *args):
"""Call forward."""
- try:
- return self.forward(*args) # pylint: disable= no-value-for-parameter
- except DeferredInitializationError:
- self.infer_shape(*[i.shape for i in args])
- for i in self.params.values():
- i._finish_deferred_init()
- return self.forward(*args) # pylint: disable= no-value-for-parameter
-
- def forward(self, x, *args):
+ return self.call(*args) # pylint: disable=no-value-for-parameter
+
+ def call(self, x, *args):
"""Defines the forward computation. Arguments can be either NDArray or Symbol."""
if isinstance(x, NDArray):
with x.context as ctx:
- params = {k: v.data(ctx) for k, v in self._reg_params.items()}
- return self.ndarray_forward(x, *args, **params)
+ try:
+ params = {k: v.data(ctx) for k, v in self._reg_params.items()}
+ except DeferredInitializationError:
+ arg_shapes = [x.shape]
+ arg_shapes += [i.shape if isinstance(i, NDArray) else i for i in args]
+ self.infer_shape(*arg_shapes)
+ for i in self.params.values():
+ i._finish_deferred_init()
+ params = {k: v.data(ctx) for k, v in self._reg_params.items()}
+ return self.forward(ndarray, x, *args, **params)
else:
assert isinstance(x, Symbol), \
- "Layer requires the first argument to forward be either Symbol or NDArray"
+ "Layer requires the first argument to forward be either " \
+ "Symbol or NDArray, but got %s"%type(x)
params = {k: v.var() for k, v in self._reg_params.items()}
- return self.symbol_forward(x, *args, **params)
-
- def ndarray_forward(self, x, *args, **kwargs):
- return self.generic_forward(ndarray, x, *args, **kwargs)
-
- def symbol_forward(self, x, *args, **kwargs):
- return self.generic_forward(symbol, x, *args, **kwargs)
+ return self.forward(symbol, x, *args, **params)
- def generic_forward(self, F, x, *args, **kwargs):
+ def forward(self, F, x, *args, **kwargs):
"""Simple forward supports both `Symbol` and `NDArray` API.
Parameters
@@ -217,13 +216,13 @@ def add(self, layer):
"""Add layer on top of the stack."""
self.register_child(layer)
- def forward(self, x):
+ def call(self, x):
#pylint: disable=arguments-differ
for layer in self._children:
x = layer(x)
return x
- def generic_forward(self, F, x, *args, **kwargs):
+ def forward(self, F, x, *args, **kwargs):
raise NotImplementedError
@@ -284,7 +283,7 @@ def __init__(self, units, activation=None, use_bias=True,
kernel_initializer=None, bias_initializer=None,
in_units=0, **kwargs):
super(Dense, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self._op = symbol.CachedOp('FullyConnected', 3 if use_bias else 2,
num_hidden=units, no_bias=not use_bias)
self.weight = self.params.get('weight', shape=(units, in_units),
@@ -297,7 +296,7 @@ def __init__(self, units, activation=None, use_bias=True,
else:
self.act = None
- def generic_forward(self, F, x, weight, bias=None):
+ def forward(self, F, x, weight, bias=None):
if bias is None:
act = F.invoke(self._op, [x, weight])
else:
@@ -331,7 +330,7 @@ def __init__(self, activation, **kwargs):
def _alias(self):
return self._act_type
- def generic_forward(self, F, x):
+ def forward(self, F, x):
return F.invoke(self._op, [x])
@@ -348,14 +347,14 @@ class Dropout(Layer):
References
----------
- - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
+ [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
"""
def __init__(self, rate, **kwargs):
super(Dropout, self).__init__(**kwargs)
self._op = symbol.CachedOp('Dropout', 1, p=rate)
- def generic_forward(self, F, x):
+ def forward(self, F, x):
return F.invoke(self._op, [x])
@@ -407,7 +406,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-3, center=True, scale=True,
shape=(num_features,),
init=running_variance_initializer)
- def generic_forward(self, F, x, gamma, beta, running_mean, running_var):
+ def forward(self, F, x, gamma, beta, running_mean, running_var):
return F.invoke(self._op, [x, gamma, beta, running_mean, running_var])
@@ -427,7 +426,7 @@ def __init__(self, alpha, **kwargs):
super(LeakyReLU, self).__init__(**kwargs)
self._op = symbol.CachedOp('LeakyReLU', 1, act_type='leaky', slope=alpha)
- def generic_forward(self, F, x):
+ def forward(self, F, x):
return F.invoke(self._op, [x])
@@ -463,5 +462,5 @@ def __init__(self, input_dim, output_dim, dtype='float32',
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
init=embeddings_initializer)
- def generic_forward(self, F, x, weight):
+ def forward(self, F, x, weight):
return F.invoke(self._op, [x, weight])
diff --git a/python/mxnet/foo/parameter.py b/python/mxnet/foo/parameter.py
index 5901f1873d85..50c9c614b853 100644
--- a/python/mxnet/foo/parameter.py
+++ b/python/mxnet/foo/parameter.py
@@ -85,6 +85,8 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Xavier(),
allow_deferring=True):
"""Intialize parameter and gradient arrays. Only used for `NDArray` API.
+ Parameters
+ ----------
init : Initializer
The initializer to use. Overrides `Parameter.init` and default_init.
ctx : Context or list of Context, defaults to `context.current_context()`.
@@ -295,8 +297,8 @@ def get(self, name, **kwargs):
found, `get` will create a new Parameter with key-word arguments and
insert it to self.
- Parameter
- ---------
+ Parameters
+ ----------
name : str
name of the desired Parameter. It will be prepended with this dictionary's
prefix.
diff --git a/python/mxnet/foo/rnn/rnn_cell.py b/python/mxnet/foo/rnn/rnn_cell.py
index d0f6ebbcd118..2733cebe46bd 100644
--- a/python/mxnet/foo/rnn/rnn_cell.py
+++ b/python/mxnet/foo/rnn/rnn_cell.py
@@ -301,7 +301,7 @@ def _get_activation(self, F, inputs, activation, **kwargs):
else:
return activation(inputs, **kwargs)
- def forward(self, inputs, states):
+ def call(self, inputs, states):
"""Unroll the recurrent cell for one time step.
Parameters
@@ -329,7 +329,7 @@ def forward(self, inputs, states):
"""
# pylint: disable= arguments-differ
self._counter += 1
- return super(RecurrentCell, self).forward(inputs, states)
+ return super(RecurrentCell, self).call(inputs, states)
@@ -370,8 +370,8 @@ def _gate_names(self):
def _alias(self):
return 'rnn'
- def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias,
- h2h_weight, h2h_bias):
+ def forward(self, F, inputs, states, i2h_weight, i2h_bias,
+ h2h_weight, h2h_bias):
name = self._curr_prefix
i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
num_hidden=self._num_hidden,
@@ -425,8 +425,8 @@ def _gate_names(self):
def _alias(self):
return 'lstm'
- def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias,
- h2h_weight, h2h_bias):
+ def forward(self, F, inputs, states, i2h_weight, i2h_bias,
+ h2h_weight, h2h_bias):
name = self._curr_prefix
i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
num_hidden=self._num_hidden*4,
@@ -487,8 +487,8 @@ def _gate_names(self):
def _alias(self):
return 'gru'
- def generic_forward(self, F, inputs, states, i2h_weight, i2h_bias,
- h2h_weight, h2h_bias):
+ def forward(self, F, inputs, states, i2h_weight, i2h_bias,
+ h2h_weight, h2h_bias):
# pylint: disable=too-many-locals
name = self._curr_prefix
prev_state_h = states[0]
@@ -780,7 +780,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
return inputs, next_states
- def generic_forward(self, *args, **kwargs):
+ def forward(self, *args, **kwargs):
raise NotImplementedError
@@ -804,7 +804,7 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'dropout'
- def generic_forward(self, F, inputs, states):
+ def forward(self, F, inputs, states):
if self.dropout > 0:
inputs = F.Dropout(data=inputs, p=self.dropout)
return inputs, states
@@ -814,7 +814,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
if isinstance(inputs, tensor_types):
- return self.generic_forward(F, inputs, begin_state if begin_state else [])
+ return self.forward(F, inputs, begin_state if begin_state else [])
else:
return super(DropoutCell, self).unroll(
length, inputs, begin_state=begin_state, layout=layout,
@@ -858,7 +858,7 @@ def unpack_weights(self, args):
def pack_weights(self, args):
return self.base_cell.pack_weights(args)
- def generic_forward(self, F, inputs, states):
+ def forward(self, F, inputs, states):
raise NotImplementedError
@@ -886,7 +886,7 @@ def reset(self):
super(ZoneoutCell, self).reset()
self.prev_output = None
- def generic_forward(self, F, inputs, states):
+ def forward(self, F, inputs, states):
cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states
next_output, next_states = cell(inputs, states)
mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p))
@@ -915,7 +915,7 @@ class ResidualCell(ModifierCell):
def __init__(self, base_cell):
super(ResidualCell, self).__init__(base_cell)
- def generic_forward(self, F, inputs, states):
+ def forward(self, F, inputs, states):
output, states = self.base_cell(inputs, states)
output = F.elemwise_add(output, inputs, name="%s_plus_residual" % output.name)
return output, states
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 8900843f5937..a70b81b8f99b 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -62,6 +62,12 @@
3 : np.uint8,
4 : np.int32
}
+
+_GRAD_REQ_MAP = {
+ 'null': 0,
+ 'write': 1,
+ 'add': 3
+}
# pylint: enable= no-member
def _new_empty_handle():
@@ -116,8 +122,9 @@ class NDArray(NDArrayBase):
def __repr__(self):
"""Returns a string representation of the array."""
shape_info = 'x'.join(['%d' % x for x in self.shape])
- return '<%s %s @%s>' % (self.__class__.__name__,
- shape_info, self.context)
+ return '%s\n<%s %s @%s>' % (str(self.asnumpy()),
+ self.__class__.__name__,
+ shape_info, self.context)
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
@@ -926,6 +933,34 @@ def as_in_context(self, context):
return self
return self.copyto(context)
+ def set_grad(self, grad_req='write'):
+ """Attach a gradient buffer to this NDArray, so that `backward`
+ can compute gradient with respect to it.
+
+ Parameters
+ ----------
+ grad_req : {'write', 'add', 'null'}
+ How gradient will be accumulated.
+ - 'write': gradient will be overwritten on every backward.
+ - 'add': gradient will be added to existing value on every backward.
+ - 'null': do not compute gradient for this NDArray.
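+
+        Examples
+        --------
+        >>> # minimal usage sketch (mirrors tests/python/unittest/test_autograd.py)
+        >>> from mxnet.autograd import train_section
+        >>> x = mx.nd.zeros((10,))
+        >>> x.set_grad()
+        >>> with train_section():
+        ...     y = x * 2
+        ...     y.backward()
+        >>> (x.grad.asnumpy() == 2).all()
+        True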
+ """
+ grad = zeros_like(self) # pylint: disable=undefined-variable
+ grad_req = _GRAD_REQ_MAP[grad_req]
+ check_call(_LIB.MXAutogradMarkVariables(
+ 1, ctypes.pointer(self.handle),
+ ctypes.pointer(mx_uint(grad_req)),
+ ctypes.pointer(grad.handle)))
+
+ @property
+ def grad(self):
+ """Returns gradient buffer attached to this NDArray."""
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayGetGrad(self.handle, ctypes.byref(hdl)))
+ if hdl.value is None:
+ return None
+ return NDArray(hdl)
+
def detach(self):
"""Returns a new NDArray, detached from the current graph."""
hdl = NDArrayHandle()
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 35c36841191b..84199b952a54 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -17,7 +17,7 @@
from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle
from .base import check_call, MXNetError, NotImplementedForSymbol, _Null # pylint: disable=unused-import
from .context import Context
-from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
from .name import NameManager # pylint: disable=unused-import
from .executor import Executor
from . import _symbol_internal as _internal
@@ -42,7 +42,6 @@
from ._ctypes.symbol import SymbolBase, _set_symbol_class
from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import
-_GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3}
class Symbol(SymbolBase):
"""Symbol is symbolic graph of the mxnet."""
@@ -1572,7 +1571,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write',
executor.aux_arrays = aux_states
return executor
- def grad(self, wrt):
+ def gradient(self, wrt):
"""Gets the autodiff of current symbol.
This function can only be used if current symbol is a loss function.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 9d60c8615027..cfd26301388a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -398,6 +398,19 @@ int MXNDArrayGetContext(NDArrayHandle handle,
API_END();
}
+
+int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) {
+ API_BEGIN();
+  NDArray *arr = static_cast<NDArray*>(handle);
+ NDArray ret = arr->grad();
+ if (ret.is_none()) {
+ *out = NULL;
+ } else {
+ *out = new NDArray(ret);
+ }
+ API_END();
+}
+
int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) {
API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 6f1795d6f368..279ae9617fed 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -24,6 +24,14 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg);
namespace mxnet {
+NDArray NDArray::grad() const {
+ if (this->entry_.ag_node && this->entry_.ag_node->out_grads.size()) {
+ CHECK_EQ(this->entry_.ag_node->out_grads.size(), 1);
+ return this->entry_.ag_node->out_grads[0];
+ }
+ return NDArray();
+}
+
NDArray NDArray::Reshape(const TShape &shape) const {
using namespace autograd;
if (AutogradRuntime::Get()->IsTraining()) {
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 9b2ea4b867f3..eb73a125e819 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -234,6 +234,17 @@ def test_retain_grad():
"differentiating the same graph twice without retain_graph should fail")
+def test_set_grad():
+ x = mx.nd.zeros((10,))
+ assert x.grad is None
+ x.set_grad()
+ with train_section():
+ y = x * 2
+ assert y.grad is None
+ y.backward()
+ assert (x.grad.asnumpy() == 2).all()
+
+
if __name__ == "__main__":
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py
index 42917855df34..bd9eca662fa4 100644
--- a/tests/python/unittest/test_nn.py
+++ b/tests/python/unittest/test_nn.py
@@ -27,11 +27,11 @@ def test_parameter_sharing():
class Net(nn.Layer):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
- with self.scope:
+ with self.name_scope():
self.dense0 = nn.Dense(5, in_units=5)
self.dense1 = nn.Dense(5, in_units=5)
- def generic_forward(self, F, x):
+ def forward(self, F, x):
return self.dense1(self.dense0(x))
net1 = Net(prefix='net1_')