Bidirectional lstm example #2093 (#2096)
* bi lstm examples

* add readme

* do not change config.mk

* config mk
xlvector authored and antinucleon committed May 12, 2016
1 parent 2cff740 commit 6d99054
Showing 7 changed files with 597 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -83,6 +83,7 @@ R-package/man/*.Rd
*.zip
*ubyte
*.bin
*.txt

# ipython notebook
*_pb2.py
24 changes: 24 additions & 0 deletions example/bi-lstm-sort/README.md
@@ -0,0 +1,24 @@
This is an example of using a bidirectional LSTM to sort an array.

First, generate the data:

cd data
python gen_data.py

Then, train the model:

python lstm_sort.py

Finally, test the model:

python infer_sort.py 234 189 785 763 231

which will output the sorted sequence:

189
231
234
763
785
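
Note: the data-generation and data-iterator scripts referenced above (data/gen_data.py, sort_io.py) are not rendered in this view. As a rough, illustrative sketch only (the real gen_data.py may differ), training data of this form could be produced like so; the sequence length of 5 and the value range are assumptions:

# Hypothetical sketch, not part of this commit: write one space-separated
# sequence of random integers per line, matching the format infer_sort.py reads.
import random

def gen_file(path, num_lines=10000, seq_len=5, lo=100, hi=1000):
    with open(path, "w") as f:
        for _ in range(num_lines):
            seq = [str(random.randint(lo, hi)) for _ in range(seq_len)]
            f.write(" ".join(seq) + "\n")

if __name__ == "__main__":
    gen_file("sort.train.txt")
    gen_file("sort.valid.txt")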


50 changes: 50 additions & 0 deletions example/bi-lstm-sort/infer_sort.py
@@ -0,0 +1,50 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from sort_io import BucketSentenceIter, default_build_vocab
from rnn_model import BiLSTMInferenceModel

def MakeInput(char, vocab, arr):
idx = vocab[char]
tmp = np.zeros((1,))
tmp[0] = idx
arr[:] = tmp

if __name__ == '__main__':
batch_size = 1
buckets = []
num_hidden = 300
num_embed = 512
num_lstm_layer = 2

num_epoch = 1
learning_rate = 0.1
momentum = 0.9

contexts = [mx.context.gpu(i) for i in range(1)]

vocab = default_build_vocab("./data/sort.train.txt")
rvocab = {}
for k, v in vocab.items():
rvocab[v] = k

_, arg_params, __ = mx.model.load_checkpoint("sort", 1)
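# (load_checkpoint("sort", 1) reads sort-symbol.json and sort-0001.params, written by lstm_sort.py)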

model = BiLSTMInferenceModel(5, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)

tks = sys.argv[1:]
data = np.zeros((1, len(tks)))
for k in range(len(tks)):
data[0][k] = vocab[tks[k]]

data = mx.nd.array(data)
prob = model.forward(data)
for k in range(len(tks)):
print(rvocab[np.argmax(prob, axis=1)[k]])

159 changes: 159 additions & 0 deletions example/bi-lstm-sort/lstm.py
@@ -0,0 +1,159 @@
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
"h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
"init_states", "last_states", "forward_state", "backward_state",
"seq_data", "seq_labels", "seq_outputs",
"param_blocks"])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
"""LSTM Cell symbol"""
if dropout > 0.:
indata = mx.sym.Dropout(data=indata, p=dropout)
i2h = mx.sym.FullyConnected(data=indata,
weight=param.i2h_weight,
bias=param.i2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_i2h" % (seqidx, layeridx))
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.h2h_weight,
bias=param.h2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_h2h" % (seqidx, layeridx))
gates = i2h + h2h
slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
name="t%d_l%d_slice" % (seqidx, layeridx))
in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
return LSTMState(c=next_c, h=next_h)


def bi_lstm_unroll(seq_len, input_size,
num_hidden, num_embed, num_label, dropout=0.):

embed_weight = mx.sym.Variable("embed_weight")
cls_weight = mx.sym.Variable("cls_weight")
cls_bias = mx.sym.Variable("cls_bias")
last_states = []
last_states.append(LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")))
last_states.append(LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h")))
forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
i2h_bias=mx.sym.Variable("l0_i2h_bias"),
h2h_weight=mx.sym.Variable("l0_h2h_weight"),
h2h_bias=mx.sym.Variable("l0_h2h_bias"))
backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
i2h_bias=mx.sym.Variable("l1_i2h_bias"),
h2h_weight=mx.sym.Variable("l1_h2h_weight"),
h2h_bias=mx.sym.Variable("l1_h2h_bias"))

# embeding layer
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')
embed = mx.sym.Embedding(data=data, input_dim=input_size,
weight=embed_weight, output_dim=num_embed, name='embed')
wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

forward_hidden = []
for seqidx in range(seq_len):
hidden = wordvec[seqidx]
next_state = lstm(num_hidden, indata=hidden,
prev_state=last_states[0],
param=forward_param,
seqidx=seqidx, layeridx=0, dropout=dropout)
hidden = next_state.h
last_states[0] = next_state
forward_hidden.append(hidden)

backward_hidden = []
for seqidx in range(seq_len):
k = seq_len - seqidx - 1
hidden = wordvec[k]
next_state = lstm(num_hidden, indata=hidden,
prev_state=last_states[1],
param=backward_param,
seqidx=k, layeridx=1, dropout=dropout)
hidden = next_state.h
last_states[1] = next_state
backward_hidden.insert(0, hidden)

hidden_all = []
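# concatenate forward and backward hidden states per time step (dim=1), then
# stack all seq_len steps along dim=0, giving (seq_len * batch_size, 2 * num_hidden)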
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))

hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
weight=cls_weight, bias=cls_bias, name='pred')

label = mx.sym.transpose(data=label)
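# flatten the transposed label to (seq_len * batch_size,) so it lines up with the time-major rows of pred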
label = mx.sym.Reshape(data=label, target_shape=(0,))
sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

return sm


def bi_lstm_inference_symbol(input_size, seq_len,
num_hidden, num_embed, num_label, dropout=0.):
seqidx = 0
embed_weight = mx.sym.Variable("embed_weight")
cls_weight = mx.sym.Variable("cls_weight")
cls_bias = mx.sym.Variable("cls_bias")
last_states = [LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")),
LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h"))]
forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
i2h_bias=mx.sym.Variable("l0_i2h_bias"),
h2h_weight=mx.sym.Variable("l0_h2h_weight"),
h2h_bias=mx.sym.Variable("l0_h2h_bias"))
backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
i2h_bias=mx.sym.Variable("l1_i2h_bias"),
h2h_weight=mx.sym.Variable("l1_h2h_weight"),
h2h_bias=mx.sym.Variable("l1_h2h_bias"))
data = mx.sym.Variable("data")
embed = mx.sym.Embedding(data=data, input_dim=input_size,
weight=embed_weight, output_dim=num_embed, name='embed')
wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
forward_hidden = []
for seqidx in range(seq_len):
next_state = lstm(num_hidden, indata=wordvec[seqidx],
prev_state=last_states[0],
param=forward_param,
seqidx=seqidx, layeridx=0, dropout=0.0)
hidden = next_state.h
last_states[0] = next_state
forward_hidden.append(hidden)

backward_hidden = []
for seqidx in range(seq_len):
k = seq_len - seqidx - 1
next_state = lstm(num_hidden, indata=wordvec[k],
prev_state=last_states[1],
param=backward_param,
seqidx=k, layeridx=1, dropout=0.0)
hidden = next_state.h
last_states[1] = next_state
backward_hidden.insert(0, hidden)

hidden_all = []
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
weight=cls_weight, bias=cls_bias, name='pred')
sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
output = [sm]
for state in last_states:
output.append(state.c)
output.append(state.h)
return mx.sym.Group(output)
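
As a quick sanity check of the unrolled training symbol (illustrative only, not part of this commit), the shapes can be inspected with infer_shape. The values below are assumptions: input_size=1000 and num_label=1000 stand in for len(vocab), batch_size=100 and seq_len=5 mirror lstm_sort.py.

import sys
sys.path.insert(0, "../../python")
from lstm import bi_lstm_unroll

# build the unrolled symbol with placeholder sizes
sym = bi_lstm_unroll(seq_len=5, input_size=1000,
                     num_hidden=300, num_embed=512, num_label=1000)
arg_shapes, out_shapes, _ = sym.infer_shape(
    data=(100, 5), softmax_label=(100, 5),
    l0_init_c=(100, 300), l0_init_h=(100, 300),
    l1_init_c=(100, 300), l1_init_h=(100, 300))
print(out_shapes)  # [(500, 1000)], i.e. (seq_len * batch_size, num_label)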

68 changes: 68 additions & 0 deletions example/bi-lstm-sort/lstm_sort.py
@@ -0,0 +1,68 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import bi_lstm_unroll
from sort_io import BucketSentenceIter, default_build_vocab

def Perplexity(label, pred):
label = label.T.reshape((-1,))
loss = 0.
for i in range(pred.shape[0]):
loss += -np.log(max(1e-10, pred[i][int(label[i])]))
return np.exp(loss / label.size)

if __name__ == '__main__':
batch_size = 100
buckets = []
num_hidden = 300
num_embed = 512
num_lstm_layer = 2

num_epoch = 1
learning_rate = 0.1
momentum = 0.9

contexts = [mx.context.gpu(i) for i in range(1)]

vocab = default_build_vocab("./data/sort.train.txt")

def sym_gen(seq_len):
return bi_lstm_unroll(seq_len, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab))

init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h

data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
buckets, batch_size, init_states)
data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
buckets, batch_size, init_states)

if len(buckets) == 1:
symbol = sym_gen(buckets[0])
else:
symbol = sym_gen

model = mx.model.FeedForward(ctx=contexts,
symbol=symbol,
num_epoch=num_epoch,
learning_rate=learning_rate,
momentum=momentum,
wd=0.00001,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

model.fit(X=data_train, eval_data=data_val,
eval_metric=mx.metric.np(Perplexity),
batch_end_callback=mx.callback.Speedometer(batch_size, 50),)

model.save("sort")
57 changes: 57 additions & 0 deletions example/bi-lstm-sort/rnn_model.py
@@ -0,0 +1,57 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import LSTMState, LSTMParam, lstm, bi_lstm_inference_symbol

class BiLSTMInferenceModel(object):
def __init__(self,
seq_len,
input_size,
num_hidden,
num_embed,
num_label,
arg_params,
ctx=mx.cpu(),
dropout=0.):
self.sym = bi_lstm_inference_symbol(input_size, seq_len,
num_hidden,
num_embed,
num_label,
dropout)
batch_size = 1
init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(2)]
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(2)]

data_shape = [("data", (batch_size, seq_len, ))]

input_shapes = dict(init_c + init_h + data_shape)
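# note: the executor below is bound on CPU regardless of the ctx argument passed to __init__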
self.executor = self.sym.simple_bind(ctx=mx.cpu(), **input_shapes)

for key in self.executor.arg_dict.keys():
if key in arg_params:
arg_params[key].copyto(self.executor.arg_dict[key])

state_name = []
for i in range(2):
state_name.append("l%d_init_c" % i)
state_name.append("l%d_init_h" % i)

self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
self.input_arr = mx.nd.zeros(data_shape[0][1])

def forward(self, input_data, new_seq=False):
if new_seq:
for key in self.states_dict.keys():
self.executor.arg_dict[key][:] = 0.
input_data.copyto(self.executor.arg_dict["data"])
self.executor.forward()
for key in self.states_dict.keys():
self.states_dict[key].copyto(self.executor.arg_dict[key])
prob = self.executor.outputs[0].asnumpy()
return prob


