Commit: add Baidu warp-ctc support and an OCR example. Squashed history: local config; warpctc init code; warpctc cpu can run; warpctc gpu; warpctc: use cpu success; ocr example; add warpctc path to config; add readme; Update README.md; fix code style (x3); add "cannot find -lwarpctc" to README; add library path to warpctc; label size is diff from output size in ctc; remove change in gitignore; remove debug code (x2); free cuda memory and fix test fail; dmlc-core mshadow to current version.

9 changed files with 808 additions and 1 deletion.
**README.md**
# Baidu Warp CTC with MXNet

Baidu warp-ctc is a CTC (Connectionist Temporal Classification) implementation from Baidu with GPU support. CTC can be combined with an LSTM to solve label alignment problems in many areas, such as OCR and speech recognition.
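As a quick illustration of what CTC decoding does (a minimal sketch, not part of the commit): the network emits one prediction per time step, and a greedy decode collapses repeated labels and drops the blank (class 0 in the OCR example below), which is exactly what the `ctc_label` helper in `lstm_ocr.py` implements.

```
# Minimal sketch of greedy CTC decoding; assumes class 0 is the CTC blank.
def greedy_ctc_decode(frames):
    decoded, prev = [], 0
    for c in frames:
        if c != 0 and c != prev:
            decoded.append(c)
        prev = c
    return decoded

print(greedy_ctc_decode([0, 3, 3, 0, 0, 5, 5, 2, 0]))  # -> [3, 5, 2]
```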
## Install Baidu warp-ctc

```
cd ~/
git clone https://github.com/baidu-research/warp-ctc
cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install
```
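
If the MXNet build in the next step fails with `cannot find -lwarpctc` (a failure mode the commit history mentions), the linker is not seeing the installed library. Assuming `make install` used the default `/usr/local` prefix, pointing the build and the runtime loader at it usually resolves this:

```
export LIBRARY_PATH=/usr/local/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
```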

## Enable warpctc in MXNet

Uncomment (or add) the following lines in make/config.mk, with WARPCTC_PATH pointing at the warp-ctc checkout from the previous step:

```
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
```

Then rebuild MXNet:

```
make clean && make -j4
```
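
After the rebuild, a quick way to check that the plugin made it into the build (assuming the rebuilt Python package is the one on your path) is:

```
python -c "import mxnet as mx; print(hasattr(mx.sym, 'WarpCTC'))"
```

which should print `True`.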

## Run examples

I implemented two examples. The first is a toy example that simply verifies the CTC integration is correct; the second is an OCR example using LSTM+CTC. Run the OCR example with:

```
cd examples/warpctc
python lstm_ocr.py
```

The OCR example is constructed as follows:

1. An 80x30 image of a 4-digit captcha is generated with a Python captcha library.
2. The image is fed to the LSTM as 80 inputs, one per column, each a 30-dimensional vector (see the sketch after this list).
3. The output layer uses CTC loss.

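The column-wise slicing in step 2 is easy to see in isolation. Below is a minimal numpy sketch distilled from the `OCRIter` iterator in `lstm_ocr.py` (the captcha generation itself is omitted):

```
import numpy as np

# a grayscale captcha image: 30 rows x 80 columns (height x width)
img = np.zeros((30, 80), dtype='uint8')

# make the 80 columns the sequence axis, flatten, and scale to [0, 1];
# mx.sym.SliceChannel later splits this back into 80 inputs of 30 values
seq = img.transpose(1, 0).reshape((80 * 30)) * (1 / 255.0)
assert seq.shape == (2400,)  # matches ('data', (batch_size, 2400)) in OCRIter
```
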
The following code shows how the net is constructed:

```
def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    # every column of the image is an input; there are seq_len inputs
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 classes: 10 digits plus the CTC blank (class 0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    # here we do NOT need to transpose the label as other lstm examples do
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    # the label must be int type, so cast it
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
    return sm
```
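
Two details of this construction are worth spelling out (both inferred from the code itself). First, `mx.sym.Concat(..., dim=0)` stacks the per-step hidden states along the batch dimension, so `pred` is time-major with shape `(seq_len * batch_size, 11)`; the `Accuracy` metric in `lstm_ocr.py` relies on exactly this layout. Second, the 4-digit label is much shorter than the 80-step output sequence; this "label size differs from output size" situation, mentioned in the commit history, is precisely the case CTC (and hence WarpCTC) handles, with class 0 acting as the blank.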

**lstm.py**

# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol"""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    # one projection of size 4 * num_hidden, sliced into the four gates
    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                      name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
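
For reference, a single unrolled step of this cell wires up as follows (a hypothetical snippet with made-up symbol names, not part of the file, just to show the plumbing):

```
# one cell step: an input column in, a num_hidden-dim state out
params = LSTMParam(i2h_weight=mx.sym.Variable('l0_i2h_weight'),
                   i2h_bias=mx.sym.Variable('l0_i2h_bias'),
                   h2h_weight=mx.sym.Variable('l0_h2h_weight'),
                   h2h_bias=mx.sym.Variable('l0_h2h_bias'))
prev = LSTMState(c=mx.sym.Variable('l0_init_c'), h=mx.sym.Variable('l0_init_h'))
step = lstm(100, indata=mx.sym.Variable('x'), prev_state=prev,
            param=params, seqidx=0, layeridx=0)
# step.h feeds the next layer; (step.c, step.h) carry to the next time step
```
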
def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)

    # input and label; every column of the image is one time step
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # time-major stack: (seq_len * batch_size, num_hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 classes: 10 digits plus the CTC blank (class 0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)

    # unlike other lstm examples, the label does not need to be transposed
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    # WarpCTC expects an int32 label
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
    return sm

**lstm_ocr.py**

# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys, random
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import lstm_unroll

from captcha.image import ImageCaptcha
import cv2

class SimpleBatch(object):
    def __init__(self, data_names, data, label_names, label):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names

        self.pad = 0
        self.index = None  # unused here; kept for the DataBatch interface

    @property
    def provide_data(self):
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

def gen_rand():
    # random 4-digit string, zero-padded, e.g. "0042"
    num = random.randint(0, 9999)
    buf = str(num)
    while len(buf) < 4:
        buf = "0" + buf
    return buf

def get_label(buf):
    # digits map to classes 1-10; class 0 is reserved for the CTC blank
    ret = np.zeros(4)
    for i in range(4):
        ret[i] = 1 + int(buf[i])
    return ret

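For instance (a hypothetical run):

```
num = gen_rand()     # e.g. "0042" (always 4 characters, zero-padded)
get_label("0042")    # -> array([1., 1., 5., 3.]): digit d becomes class d + 1
```
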
class OCRIter(mx.io.DataIter):
    def __init__(self, count, batch_size, num_label, init_states):
        super(OCRIter, self).__init__()
        self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        self.provide_data = [('data', (batch_size, 2400))] + init_states
        self.provide_label = [('label', (self.batch_size, 4))]

    def __iter__(self):
        print('iter')
        init_state_names = [x[0] for x in self.init_states]
        for k in range(self.count):
            data = []
            label = []
            for i in range(self.batch_size):
                num = gen_rand()
                img = self.captcha.generate(num)
                img = np.fromstring(img.getvalue(), dtype='uint8')
                img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
                img = cv2.resize(img, (80, 30))
                # one column per time step: transpose to (width, height),
                # flatten, and scale to [0, 1]
                img = img.transpose(1, 0)
                img = img.reshape((80 * 30))
                img = np.multiply(img, 1 / 255.0)
                data.append(img)
                label.append(get_label(num))

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names
            label_names = ['label']

            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        pass

BATCH_SIZE = 32
SEQ_LENGTH = 80

def ctc_label(p):
    # greedy CTC collapse: drop blanks (0) and repeated labels,
    # e.g. [0, 3, 3, 0, 5, 5] -> [3, 5]
    ret = []
    p1 = [0] + p
    for i in range(len(p)):
        c1 = p1[i]
        c2 = p1[i + 1]
        if c2 == 0 or c2 == c1:
            continue
        ret.append(c2)
    return ret

def Accuracy(label, pred):
    # sequence-level accuracy: a sample counts as a hit only if the
    # collapsed prediction matches the whole 4-digit label exactly
    global BATCH_SIZE
    global SEQ_LENGTH
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        l = label[i]
        p = []
        for k in range(SEQ_LENGTH):
            # pred is time-major: row k * BATCH_SIZE + i is step k of sample i
            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
        p = ctc_label(p)
        if len(p) == len(l):
            match = True
            for k in range(len(p)):
                if p[k] != int(l[k]):
                    match = False
                    break
            if match:
                hit += 1.0
        total += 1.0
    return hit / total

if __name__ == '__main__':
    num_hidden = 100
    num_lstm_layer = 2

    num_epoch = 10
    learning_rate = 0.001
    momentum = 0.9
    num_label = 4

    # NOTE: this picks the second GPU; use mx.context.gpu(0) or
    # mx.context.cpu() if your machine is set up differently
    contexts = [mx.context.gpu(1)]

    def sym_gen(seq_len):
        return lstm_unroll(num_lstm_layer, seq_len,
                           num_hidden=num_hidden,
                           num_label=num_label)

    init_c = [('l%d_init_c' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    data_train = OCRIter(10000, BATCH_SIZE, num_label, init_states)
    data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states)

    symbol = sym_gen(SEQ_LENGTH)

    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    print('begin fit')

    model.fit(X=data_train, eval_data=data_val,
              eval_metric=mx.metric.np(Accuracy),
              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)

    model.save("ocr")
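
The commit ends here, but as a usage note: the checkpoint written by `model.save("ocr")` can be reloaded with the matching `mx.model.FeedForward.load("ocr", num_epoch)` from the same old API, and the output of one forward pass can be decoded with a small helper distilled from the indexing in `Accuracy` above (a hypothetical sketch, not part of the commit):

```
def decode_sample(probs, i):
    # probs: the (SEQ_LENGTH * BATCH_SIZE, 11) softmax output of one forward pass
    raw = [int(np.argmax(probs[k * BATCH_SIZE + i])) for k in range(SEQ_LENGTH)]
    return [c - 1 for c in ctc_label(raw)]  # undo the digit -> digit + 1 offset
```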