Bidirectional lstm example #2093 (#2096)
* bi lstm examples
* add readme
* do not change config.mk
* config mk
1 parent 2cff740 · commit 6d99054
Showing 7 changed files with 597 additions and 0 deletions.
.gitignore
@@ -83,6 +83,7 @@ R-package/man/*.Rd
*.zip
*ubyte
*.bin
*.txt

# ipython notebook
*_pb2.py
README.md
@@ -0,0 +1,24 @@
This is an example of using a bidirectional LSTM to sort an array.

First, generate the data:

    cd data
    python gen_data.py

Then, train the model:

    python lstm_sort.py

Finally, test the model:

    python infer_sort.py 234 189 785 763 231

which will output the sorted sequence:

    189
    231
    234
    763
    785
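The generator itself (data/gen_data.py) is not among the files rendered in this view, so the sketch below is only an illustration of what it might do. It assumes each line of sort.train.txt and sort.valid.txt is a space-separated sequence of random three-digit integers; the actual layout is defined by the reader in sort_io.py, so treat the file names, format, and line counts here as assumptions.

    import random

    def gen_file(path, num_lines, seq_len=5, lo=100, hi=999):
        # write num_lines lines, each a space-separated random integer sequence
        with open(path, "w") as f:
            for _ in range(num_lines):
                seq = [str(random.randint(lo, hi)) for _ in range(seq_len)]
                f.write(" ".join(seq) + "\n")

    if __name__ == "__main__":
        random.seed(0)
        gen_file("sort.train.txt", 20000)
        gen_file("sort.valid.txt", 1000)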
infer_sort.py
@@ -0,0 +1,50 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from sort_io import BucketSentenceIter, default_build_vocab
from rnn_model import BiLSTMInferenceModel

def MakeInput(char, vocab, arr):
    # map a token to its vocabulary index and write it into arr
    idx = vocab[char]
    tmp = np.zeros((1,))
    tmp[0] = idx
    arr[:] = tmp

if __name__ == '__main__':
    batch_size = 1
    buckets = []
    num_hidden = 300
    num_embed = 512
    num_lstm_layer = 2

    num_epoch = 1
    learning_rate = 0.1
    momentum = 0.9

    contexts = [mx.context.gpu(i) for i in range(1)]

    vocab = default_build_vocab("./data/sort.train.txt")
    # reverse vocabulary: index -> token
    rvocab = {}
    for k, v in vocab.items():
        rvocab[v] = k

    _, arg_params, __ = mx.model.load_checkpoint("sort", 1)

    model = BiLSTMInferenceModel(5, len(vocab),
                                 num_hidden=num_hidden, num_embed=num_embed,
                                 num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)

    tks = sys.argv[1:]
    data = np.zeros((1, len(tks)))
    for k in range(len(tks)):
        data[0][k] = vocab[tks[k]]

    data = mx.nd.array(data)
    prob = model.forward(data)
    for k in range(len(tks)):
        print(rvocab[np.argmax(prob, axis=1)[k]])
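The final loop takes the per-position softmax output, whose rows correspond to sequence positions and columns to vocabulary entries, and maps each row's argmax back to a token through the reversed vocabulary. A self-contained toy illustration of that decoding step (the vocabulary and probabilities below are made up for the example):

    import numpy as np

    # toy reversed vocabulary: index -> token
    rvocab = {0: "189", 1: "231", 2: "234"}

    # pretend per-position class probabilities, shape (seq_len, vocab_size)
    prob = np.array([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1],
                     [0.2, 0.2, 0.6]])

    print([rvocab[i] for i in np.argmax(prob, axis=1)])  # ['189', '231', '234']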
lstm.py
@@ -0,0 +1,159 @@
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states", "forward_state", "backward_state",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """LSTM Cell symbol"""
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    # all four gates are computed in one FullyConnected of size 4 * num_hidden
    gates = i2h + h2h
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)


def bi_lstm_unroll(seq_len, input_size,
                   num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    last_states = []
    last_states.append(LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")))
    last_states.append(LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h")))
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    # forward pass over the sequence
    forward_hidden = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=dropout)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    # backward pass over the sequence
    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        hidden = wordvec[k]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=dropout)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    # concatenate forward and backward hidden states at each position
    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')

    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

    return sm


def bi_lstm_inference_symbol(input_size, seq_len,
                             num_hidden, num_embed, num_label, dropout=0.):
    seqidx = 0
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    last_states = [LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")),
                   LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h"))]
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))
    data = mx.sym.Variable("data")
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    forward_hidden = []
    for seqidx in range(seq_len):
        next_state = lstm(num_hidden, indata=wordvec[seqidx],
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=0.0)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        next_state = lstm(num_hidden, indata=wordvec[k],
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=0.0)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
    # group the softmax output with the final forward/backward states
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
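A quick sketch of how the unrolled training symbol above can be inspected (this call is not part of the file; input_size and num_label here stand in for len(vocab), and the other sizes match the ones lstm_sort.py uses):

    import mxnet as mx
    from lstm import bi_lstm_unroll

    # unrolled bidirectional symbol for sequences of length 5
    sym = bi_lstm_unroll(seq_len=5, input_size=1000,
                         num_hidden=300, num_embed=512, num_label=1000)

    # arguments include data/softmax_label, the l0_*/l1_* LSTM parameters,
    # the embedding and classifier weights, and the four init states
    print(sym.list_arguments())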
lstm_sort.py
@@ -0,0 +1,68 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import bi_lstm_unroll
from sort_io import BucketSentenceIter, default_build_vocab

def Perplexity(label, pred):
    # average negative log-likelihood over all positions, exponentiated
    label = label.T.reshape((-1,))
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
    return np.exp(loss / label.size)

if __name__ == '__main__':
    batch_size = 100
    buckets = []
    num_hidden = 300
    num_embed = 512
    num_lstm_layer = 2

    num_epoch = 1
    learning_rate = 0.1
    momentum = 0.9

    contexts = [mx.context.gpu(i) for i in range(1)]

    vocab = default_build_vocab("./data/sort.train.txt")

    def sym_gen(seq_len):
        return bi_lstm_unroll(seq_len, len(vocab),
                              num_hidden=num_hidden, num_embed=num_embed,
                              num_label=len(vocab))

    init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
                                    buckets, batch_size, init_states)
    data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
                                  buckets, batch_size, init_states)

    if len(buckets) == 1:
        symbol = sym_gen(buckets[0])
    else:
        symbol = sym_gen

    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    model.fit(X=data_train, eval_data=data_val,
              eval_metric=mx.metric.np(Perplexity),
              batch_end_callback=mx.callback.Speedometer(batch_size, 50),)

    model.save("sort")
rnn_model.py
@@ -0,0 +1,57 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import LSTMState, LSTMParam, lstm, bi_lstm_inference_symbol

class BiLSTMInferenceModel(object):
    def __init__(self,
                 seq_len,
                 input_size,
                 num_hidden,
                 num_embed,
                 num_label,
                 arg_params,
                 ctx=mx.cpu(),
                 dropout=0.):
        self.sym = bi_lstm_inference_symbol(input_size, seq_len,
                                            num_hidden,
                                            num_embed,
                                            num_label,
                                            dropout)
        batch_size = 1
        init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(2)]
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(2)]

        data_shape = [("data", (batch_size, seq_len, ))]

        input_shapes = dict(init_c + init_h + data_shape)
        self.executor = self.sym.simple_bind(ctx=mx.cpu(), **input_shapes)

        # copy the trained parameters into the bound executor
        for key in self.executor.arg_dict.keys():
            if key in arg_params:
                arg_params[key].copyto(self.executor.arg_dict[key])

        state_name = []
        for i in range(2):
            state_name.append("l%d_init_c" % i)
            state_name.append("l%d_init_h" % i)

        # outputs[0] is the softmax; the remaining outputs are the final LSTM states
        self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(data_shape[0][1])

    def forward(self, input_data, new_seq=False):
        if new_seq:
            # reset the recurrent states before starting a new sequence
            for key in self.states_dict.keys():
                self.executor.arg_dict[key][:] = 0.
        input_data.copyto(self.executor.arg_dict["data"])
        self.executor.forward()
        # feed the final states back in as the next call's initial states
        for key in self.states_dict.keys():
            self.states_dict[key].copyto(self.executor.arg_dict[key])
        prob = self.executor.outputs[0].asnumpy()
        return prob
(The remaining changed file in this commit did not render in this view.)