Module tutorial improvements and metric API doc linked (apache#6532)
* Module tutorial improvements

* prerequisite section added

* Metric API md file created and linked from index page

* link to fit function added

* section rearranged

* fixes after review

* fixes after review
Roshrini authored and piiswrong committed Jun 2, 2017
1 parent 1f7a7cd commit 2cbab7b
Showing 3 changed files with 122 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/api/python/index.md
@@ -33,4 +33,5 @@ imported by running:
io
optimization
callback
metric
```
28 changes: 28 additions & 0 deletions docs/api/python/metric.md
@@ -0,0 +1,28 @@
# Evaluation Metric API

```eval_rst
.. currentmodule:: mxnet.metric
```

## Overview

This document lists all the evaluation metrics available to evaluate
the performance of a learned model.

```eval_rst
.. autosummary::
    :nosignatures:

    mxnet.metric
```
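
As a quick illustration (a minimal sketch, not part of the generated
reference), a metric can be created by name, updated with batches of labels
and predictions, and queried:

```python
import mxnet as mx

# create an accuracy metric by name
acc = mx.metric.create('acc')

# toy labels and predictions: 4 samples, 2 classes
labels = [mx.nd.array([0, 1, 1, 0])]
preds = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])]

# accumulate and read the metric
acc.update(labels=labels, preds=preds)
print(acc.get())  # ('accuracy', 1.0)
```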

## API Reference

<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>

```eval_rst
.. automodule:: mxnet.metric
    :members:
```

<script>auto_index("api-reference");</script>
135 changes: 93 additions & 42 deletions docs/tutorials/basic/module.md
@@ -8,14 +8,24 @@
steps. All this can be quite daunting to newcomers and experienced
developers alike.

Luckily, MXNet modularizes commonly used code for training and inference in
the `module` (`mod` for short) package. `Module` provides both high-level and
intermediate-level interfaces for executing predefined networks. One can use
both interfaces interchangeably. We will show the usage of both interfaces in
this tutorial.

## Prerequisites

To complete this tutorial, we need:

- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/get_started/install.html)
- [Python](https://www.python.org/downloads/)

## Preliminary

In this tutorial we will demonstrate `module` usage by training a
[Multilayer Perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron) (MLP)
on the [UCI letter recognition](https://archive.ics.uci.edu/ml/datasets/letter+recognition)
dataset.

The following code downloads the dataset and creates an 80:20 train:test
split. It also initializes a training data iterator to return a batch of 32
training examples each time.
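
The download and iterator code is collapsed in this diff view; a minimal
sketch of such a setup (the dataset URL, the `mx.test_utils.download` helper,
and all variable names here are assumptions) could look like:

```python
import numpy as np
import mxnet as mx

# download the UCI letter recognition dataset (URL assumed for illustration)
fname = mx.test_utils.download(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')

# the first CSV column is the letter label 'A'-'Z'; the remaining 16 are integer features
data = np.genfromtxt(fname, delimiter=',')[:, 1:]
label = np.array([ord(line.split(',')[0]) - ord('A') for line in open(fname)])

batch_size = 32
ntrain = int(data.shape[0] * 0.8)  # 80:20 train:test split

train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)
```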
@@ -48,9 +58,7 @@
```python
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net)
```

## Creating a Module

Now we are ready to introduce the module. The commonly used module class is
`Module`. We can construct a module by specifying the following parameters:

@@ -70,12 +78,69 @@
- `symbol`: the network definition
- `context`: the device (or a list of devices) to use for execution
- `data_names`: the list of input data variable names
- `label_names`: the list of input label variable names

```python
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])
```

## Intermediate-level Interface

We have created a module. Now let us see how to run training and inference
using the module's intermediate-level APIs. These APIs give developers the
flexibility to do step-by-step computation by running `forward` and `backward`
passes, which is also useful for debugging.

To train a module, we need to perform the following steps:

- `bind`: prepares the environment for the computation by allocating memory.
- `init_params`: assigns and initializes parameters.
- `init_optimizer`: initializes the optimizer; defaults to `sgd`.
- `metric.create`: creates an evaluation metric from an input metric name.
- `forward`: runs the forward computation.
- `update_metric`: evaluates and accumulates the evaluation metric on the outputs of the last forward computation.
- `backward`: runs the backward computation.
- `update`: updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.

This can be used as follows:

```python
# allocate memory given the input data and label shapes
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# initialize parameters by uniform random numbers
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# use SGD with learning rate 0.1 to train
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train 5 epochs, i.e., going over the data iterator one pass per epoch
for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # compute predictions
        mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
        mod.backward()                          # compute gradients
        mod.update()                            # update parameters
    print('Epoch %d, Training %s' % (epoch, metric.get()))
```

To learn more about these APIs, visit [Module API](http://mxnet.io/api/python/module.html).

## High-level Interface

### Train

Module also provides high-level APIs for training, predicting, and evaluating
for user convenience. Instead of doing all the steps mentioned in the above
section, one can simply call the
[fit API](http://mxnet.io/api/python/module.html#mxnet.module.BaseModule.fit),
which internally executes the same steps.

To fit a module, call the `fit` function as follows:

```python
# reset train_iter to the beginning
train_iter.reset()

# create a module
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

# fit the module
mod.fit(train_iter,
        eval_data=val_iter,
        optimizer='sgd',
        # ... (remaining arguments are collapsed in this diff view) ...
        num_epoch=8)
```

To predict with module, simply call `predict()`. It will collect and
By default, the `fit` function sets `eval_metric` to `accuracy`, `optimizer` to
`sgd`, and `optimizer_params` to `(('learning_rate', 0.01),)`.
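
As a sketch, a fresh module trained entirely with those defaults reduces to:

```python
# relies on the defaults: eval_metric 'accuracy', optimizer 'sgd',
# and optimizer_params (('learning_rate', 0.01),)
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])
mod.fit(train_iter, eval_data=val_iter, num_epoch=8)
```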

### Predict and Evaluate

To predict with a module, call `predict()`. It will collect and
return all the prediction results.

```python
y = mod.predict(val_iter)
assert y.shape == (4000, 26)
```

If we do not need the prediction outputs, but just need to evaluate on a test
set, we can call the `score()` function. It runs prediction on the input
validation dataset and evaluates the performance according to the given input
metric.

It can be used as follows:

```python
score = mod.score(val_iter, ['mse', 'acc'])
print('Accuracy score is %s' % score)
```

Some of the other metrics that can be used are `top_k_acc` (top-k accuracy),
`F1`, `RMSE`, `MSE`, `MAE`, and `ce` (cross-entropy). To learn more about the
metrics, visit [Evaluation metric](http://mxnet.io/api/python/metric.html).
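
For instance, top-k accuracy takes an extra `top_k` parameter; a minimal
sketch using the `TopKAccuracy` metric class would be:

```python
# count a prediction as correct if the true label is among
# the 5 highest-scoring classes
top5 = mx.metric.TopKAccuracy(top_k=5)
print(mod.score(val_iter, top5))
```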

One can vary the number of epochs, the learning rate, and other optimizer
parameters, and tune them to obtain the best score.

### Save and Load

We can save the module parameters after each training epoch by using a checkpoint callback.
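
The checkpoint code itself is collapsed in this diff view; a minimal sketch,
assuming an `mx_mlp` prefix and a checkpoint saved after epoch 3, might look
like:

```python
# construct a callback that saves 'mx_mlp-symbol.json' and
# 'mx_mlp-%04d.params' files at the end of every epoch
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

# load the network and the parameters saved after epoch 3
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)

# training can then be resumed by passing the loaded parameters and
# begin_epoch=3 to fit, as the fragment below sketches
mod = mx.mod.Module(symbol=sym)
```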
@@ -139,34 +220,4 @@
```python
mod.fit(train_iter,
        # ... (other fit arguments are collapsed in this diff view) ...
        begin_epoch=3)
```

## Intermediate-level Interface

We already saw how to use a module for basic training and inference. Now we
are going to see a more flexible usage of modules. Instead of calling the
high-level `fit` and `predict` APIs, one can write a training program with
the intermediate-level APIs `forward` and `backward`.

```python
# create module
mod = mx.mod.Module(symbol=net)
# allocate memory given the input data and label shapes
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# initialize parameters by uniform random numbers
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# use SGD with learning rate 0.1 to train
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train 5 epochs, i.e., going over the data iterator one pass per epoch
for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # compute predictions
        mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
        mod.backward()                          # compute gradients
        mod.update()                            # update parameters
    print('Epoch %d, Training %s' % (epoch, metric.get()))
```

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
