Improve PTB results (apache#7059)
* Fix speech demo.

* Use a random seed for cuDNN dropout. Previously, the fixed seed would generate the same mask for every iteration in imperative mode.

* The PTB LM example now gets far better PPL: 1) forget_bias=0, 2) clipping range, 3) lr annealing, 4) initialization.

* (1) Remove the mean from the loss function (good for multi-GPU). (2) Make clip and lr sample-based. (3) Change hyperparameters; we now get slightly better results than PyTorch.

* Remove the lstmbias init in model.py since it has already been set to 0.
yzhang87 authored and piiswrong committed Jul 16, 2017
1 parent 1ae1fbe commit 5af56bb
Showing 5 changed files with 84 additions and 26 deletions.
49 changes: 49 additions & 0 deletions example/gluon/word_language_model/README.md
@@ -0,0 +1,49 @@
# Word-level language modeling RNN

This example trains a multi-layer RNN (Elman, GRU, or LSTM) on the Penn Treebank (PTB) language modeling benchmark.

The model obtains a state-of-the-art result on PTB using an LSTM, reaching a test perplexity of ~72.

The following techniques have been adopted for SOTA results:
- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf)
- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings
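
Weight tying shares one matrix between the input embedding and the output (softmax) projection, which requires `--emsize` to equal `--nhid`. Below is a minimal sketch of how this can be expressed in Gluon; the sizes are illustrative and the actual model definition lives in `model.py` in this commit:

```python
import mxnet as mx
from mxnet.gluon import nn

vocab_size, num_embed, num_hidden = 10000, 650, 650  # illustrative sizes

encoder = nn.Embedding(vocab_size, num_embed)
# Passing the embedding's ParameterDict lets the Dense layer reuse the same
# weight matrix for the output projection (shapes must match, i.e.
# num_embed == num_hidden).
decoder = nn.Dense(vocab_size, in_units=num_hidden, params=encoder.params)
```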

## Data

The PTB data is the processed version from [(Mikolov et al., 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf):

```bash
python data.py
```
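
`train.py` then arranges the token stream as a batch of parallel sequences (see `batchify` in the diff below). A rough sketch of that layout, with the reshape details assumed rather than copied from the example:

```python
import mxnet as mx

def batchify(data, batch_size):
    # Trim the stream to a multiple of batch_size and lay it out as
    # (num_steps, batch_size): each column is a contiguous slice of the corpus.
    nbatch = data.shape[0] // batch_size
    return data[:nbatch * batch_size].reshape((batch_size, nbatch)).T

print(batchify(mx.nd.arange(10), 2))  # two parallel sequences of length 5
```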

## Usage

Example runs and their results:

```bash
python train.py --cuda --tied --nhid 650 --emsize 650 --dropout 0.5 # Test ppl of 75.3
python train.py --cuda --tied --nhid 1500 --emsize 1500 --dropout 0.65 # Test ppl of 72.0
```

<br>

`python train.py --help` gives the following arguments:
```
Optional arguments:
-h, --help            show this help message and exit
--data DATA           location of the data corpus
--model MODEL         type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)
--emsize EMSIZE       size of word embeddings
--nhid NHID           number of hidden units per layer
--nlayers NLAYERS     number of layers
--lr LR               initial learning rate
--clip CLIP           gradient clipping
--epochs EPOCHS       upper epoch limit
--batch_size N        batch size
--bptt BPTT           sequence length
--dropout DROPOUT     dropout applied to layers (0 = no dropout)
--tied                tie the word embedding and softmax weights
--seed SEED           random seed
--cuda                Whether to use gpu
--log-interval N      report interval
--save SAVE           path to save the final model
```
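
After training, the parameters with the best validation loss are saved to `--save` (default `model.params`). A minimal sketch of reloading them for evaluation, assuming the `data.Corpus` interface that `train.py` uses:

```python
import mxnet as mx
import data, model

context = mx.cpu()              # or mx.gpu(0)
corpus = data.Corpus('./data')  # data path is an assumption
ntokens = len(corpus.dictionary)

net = model.RNNModel('lstm', ntokens, num_embed=650, num_hidden=650,
                     num_layers=2, dropout=0.5, tie_weights=True)
net.collect_params().load('model.params', context)
```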
6 changes: 4 additions & 2 deletions example/gluon/word_language_model/model.py
@@ -1,15 +1,17 @@
import mxnet as mx
import mxnet.ndarray as F
from mxnet import gluon
from mxnet.gluon import nn, rnn

class RNNModel(gluon.Block):
"""A model with an encoder, recurrent layer, and a decoder."""

def __init__(self, mode, vocab_size, num_embed, num_hidden,
num_layers, dropout=0.5, tie_weights=False, **kwargs):
super(RNNModel, self).__init__(**kwargs)
with self.name_scope():
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(vocab_size, num_embed)
self.encoder = nn.Embedding(vocab_size, num_embed,
weight_initializer=mx.init.Uniform(0.1))
if mode == 'rnn_relu':
self.rnn = rnn.RNN(num_hidden, 'relu', num_layers, dropout=dropout,
input_size=num_embed)
51 changes: 29 additions & 22 deletions example/gluon/word_language_model/train.py
@@ -3,7 +3,6 @@
import math
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn
import model
import data

@@ -18,27 +17,25 @@
help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
help='number of layers')
parser.add_argument('--lr', type=float, default=20,
parser.add_argument('--lr', type=float, default=1.0,
help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
parser.add_argument('--clip', type=float, default=0.2,
help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
help='batch size')
parser.add_argument('--bptt', type=int, default=35,
help='sequence length')
parser.add_argument('--dropout', type=float, default=0.2,
help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--cuda', action='store_true',
help='Whether to use gpu')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='report interval')
parser.add_argument('--save', type=str, default='model.params',
help='path to save the final model')
args = parser.parse_args()

@@ -73,49 +70,46 @@ def batchify(data, batch_size):


ntokens = len(corpus.dictionary)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
trainer = gluon.Trainer(model.collect_params(), 'sgd',
{'learning_rate': args.lr,
'momentum': 0,
'wd': 0})
loss = gluon.loss.SoftmaxCrossEntropyLoss()

###############################################################################
# Training code
###############################################################################


def get_batch(source, i):
seq_len = min(args.bptt, source.shape[0] - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len]
return data, target.reshape((-1,))


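# Detach the hidden state from the computation graph so backprop is truncated
# at each bptt-length segment (truncated BPTT).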
def detach(hidden):
if isinstance(hidden, (tuple, list)):
hidden = [i.detach() for i in hidden]
else:
hidden = hidden.detach()
return hidden


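# Average per-token loss over a data source: sum the per-sample losses and
# divide by the total number of tokens.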
def eval(data_source):
total_L = 0.0
ntotal = 0
hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
for ibatch, i in enumerate(range(0, data_source.shape[0] - 1, args.bptt)):
for i in range(0, data_source.shape[0] - 1, args.bptt):
data, target = get_batch(data_source, i)
output, hidden = model(data, hidden)
L = loss(output, target)
total_L += mx.nd.sum(L).asscalar()
ntotal += L.size
return total_L / ntotal


def train():
best_val = None
best_val = float("Inf")
for epoch in range(args.epochs):
total_L = 0.0
start_time = time.time()
@@ -129,15 +123,15 @@ def train():
L.backward()

grads = [i.grad(context) for i in model.collect_params().values()]
# Here gradient is not divided by batch_size yet.
# So we multiply max_norm by batch_size to balance it.
gluon.utils.clip_global_norm(grads, args.clip * args.batch_size)
# Here gradient is for the whole batch.
# So we multiply max_norm by batch_size and bptt size to balance it.
gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

trainer.step(args.batch_size)
total_L += mx.nd.sum(L).asscalar()

if ibatch % args.log_interval == 0 and ibatch > 0:
cur_L = total_L / args.batch_size / args.bptt / args.log_interval
cur_L = total_L / args.bptt / args.batch_size / args.log_interval
print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
epoch, ibatch, cur_L, math.exp(cur_L)))
total_L = 0.0
@@ -147,8 +141,21 @@ def train():
print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
epoch, time.time()-start_time, val_L, math.exp(val_L)))

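# Keep the parameters with the best validation loss; if validation loss stops
# improving, anneal the learning rate by 4x and reload the best parameters.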
if val_L < best_val:
best_val = val_L
test_L = eval(test_data)
model.collect_params().save(args.save)
print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
else:
args.lr = args.lr*0.25
trainer._init_optimizer('sgd',
{'learning_rate': args.lr,
'momentum': 0,
'wd': 0})
model.collect_params().load(args.save, context)

if __name__ == '__main__':
train()
model.collect_params().load(args.save, context)
test_L = eval(test_data)
print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
2 changes: 1 addition & 1 deletion example/speech-demo/decode_mxnet.py
@@ -142,7 +142,7 @@ def sym_gen(seq_len):
elif decoding_method == METHOD_SIMPLE:
for (ind, utt) in enumerate(batch.utt_id):
if utt != "GAP_UTT":
posteriors = posteriors[:batch.utt_len,1:] - np.log(data_test.label_mean[1:]).T
posteriors = posteriors[:batch.utt_len[0],1:] - np.log(data_test.label_mean[1:]).T
kaldiWriter.write(utt, posteriors)
else:
outputs = module.get_outputs()
2 changes: 1 addition & 1 deletion src/operator/cudnn_rnn-inl.h
@@ -526,7 +526,7 @@ class CuDNNRNNOp : public Operator {
cudnnRNNInputMode_t input_mode_;
cudnnDropoutDescriptor_t dropout_desc_;
Storage::Handle dropout_states_, reserve_space_;
uint64_t seed_ = 1337ull;
uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn)
size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
int workspace_size_, dropout_size_;
std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
