Skip to content

Commit

Permalink
Integrate with Baidu warpctc
Browse files Browse the repository at this point in the history
local config

warpctc init code

warpctc

cpu can run

warpctc gpu

warpctc: use cpu

success ocr example

add warpctc path to config

add readme

Update README.md

fix code style

fix code style

fix code style

add cannot find -lwarpctc to README

add library path to warpctc

label size is diff from output size in ctc

remove change in gitignore

remove debug code

remove debug code

free cuda memory and fix test fail

dmlc-core mshadow to current version
  • Loading branch information
xlvector committed Jun 16, 2016
1 parent 967e07d commit 8fd4d16
Show file tree
Hide file tree
Showing 9 changed files with 808 additions and 1 deletion.
86 changes: 86 additions & 0 deletions example/warpctc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Baidu Warp CTC with Mxnet

Baidu warp-ctc is a CTC implementation by Baidu that supports GPUs. CTC can be used with an LSTM to solve label alignment problems in many areas, such as OCR and speech recognition.

## Install baidu warpctc

```
cd ~/
git clone https://github.com/baidu-research/warp-ctc
cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install
```

## Enable warpctc in mxnet

```
uncomment the following lines in make/config.mk
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
then rebuild mxnet with
make clean && make -j4
```

## Run examples

I implemented two examples. The first is a toy example that can be used to verify that the CTC integration is correct. The second is an OCR example using LSTM+CTC. You can run it with:

```
cd example/warpctc
python lstm_ocr.py
```

The OCR example is constructed as follows:

1. An 80x30 image of a 4-digit captcha is generated with a Python captcha library.
2. The 80x30 image is fed to the LSTM as 80 inputs; each input is one column of the image (a 30-dimensional vector).
3. The output layer uses the CTC loss.

The following code shows the detailed construction of the network:

```
def lstm_unroll(num_lstm_layer, seq_len,
num_hidden, num_label):
# one set of weight/bias variables and one initial (c, h) state per LSTM layer
param_cells = []
last_states = []
for i in range(num_lstm_layer):
param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
h=mx.sym.Variable("l%d_init_h" % i))
last_states.append(state)
assert(len(last_states) == num_lstm_layer)
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
# every column of the image is one input; there are seq_len inputs
wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
# run every time step through the stacked layers, keeping the top layer's output
hidden_all = []
for seqidx in range(seq_len):
hidden = wordvec[seqidx]
for i in range(num_lstm_layer):
next_state = lstm(num_hidden, indata=hidden,
prev_state=last_states[i],
param=param_cells[i],
seqidx=seqidx, layeridx=i)
hidden = next_state.h
last_states[i] = next_state
hidden_all.append(hidden)
# stack the per-step outputs along axis 0 and project them to 11 output units
hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
# here we do NOT need to transpose the label as other lstm examples do
label = mx.sym.Reshape(data=label, target_shape=(0,))
# the label should be of int type, so cast it
label = mx.sym.Cast(data = label, dtype = 'int32')
sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
return sm
```

79 changes: 79 additions & 0 deletions example/warpctc/lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math
# Cell state (c) and hidden state (h) symbols produced by one LSTM cell step.
LSTMState = namedtuple("LSTMState", ["c", "h"])
# Weight and bias symbols for the input-to-hidden (i2h) and
# hidden-to-hidden (h2h) projections of one LSTM layer.
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
"h2h_weight", "h2h_bias"])
# NOTE(review): LSTMModel is not referenced by any code visible in this file;
# presumably kept for parity with other LSTM examples -- confirm before removing.
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
"init_states", "last_states",
"seq_data", "seq_labels", "seq_outputs",
"param_blocks"])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """Build the symbol graph for one LSTM cell step.

    Parameters
    ----------
    num_hidden : int
        Size of the hidden state; the gate projections are four times as wide.
    indata : Symbol
        Input to the cell at this time step.
    prev_state : LSTMState
        Cell state ``c`` and hidden state ``h`` from the previous step.
    param : LSTMParam
        Weight and bias symbols for the i2h and h2h projections.
    seqidx, layeridx : int
        Time-step and layer indices, used only to name the created symbols.

    Returns
    -------
    LSTMState
        The new cell state and hidden state symbols.
    """
    step = (seqidx, layeridx)
    gate_width = num_hidden * 4

    input_proj = mx.sym.FullyConnected(name="t%d_l%d_i2h" % step,
                                       data=indata,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=gate_width)
    hidden_proj = mx.sym.FullyConnected(name="t%d_l%d_h2h" % step,
                                        data=prev_state.h,
                                        weight=param.h2h_weight,
                                        bias=param.h2h_bias,
                                        num_hidden=gate_width)

    # Sum both projections, then split the result into the four gate inputs.
    pre_activations = mx.sym.SliceChannel(input_proj + hidden_proj,
                                          num_outputs=4,
                                          name="t%d_l%d_slice" % step)

    # Slice order: input gate, input transform, forget gate, output gate.
    activations = ("sigmoid", "tanh", "sigmoid", "sigmoid")
    in_gate, in_transform, forget_gate, out_gate = [
        mx.sym.Activation(pre_activations[idx], act_type=kind)
        for idx, kind in enumerate(activations)
    ]

    cell = (forget_gate * prev_state.c) + (in_gate * in_transform)
    squashed_cell = mx.sym.Activation(cell, act_type="tanh")
    return LSTMState(c=cell, h=out_gate * squashed_cell)


def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label, num_classes=11):
    """Unroll a stacked LSTM over ``seq_len`` steps and attach a WarpCTC loss.

    Parameters
    ----------
    num_lstm_layer : int
        Number of stacked LSTM layers.
    seq_len : int
        Number of time steps; the ``data`` input is split into this many slices.
    num_hidden : int
        Hidden state size of every LSTM layer.
    num_label : int
        Label length per sample, passed to WarpCTC as ``label_length``.
    num_classes : int, optional
        Number of output units of the final projection. Defaults to 11,
        the value that was previously hard-coded.

    Returns
    -------
    Symbol
        The WarpCTC output symbol.
    """
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        # Weights are created once per layer and shared across all time steps.
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    # No embedding layer here: the raw input is split directly into
    # seq_len slices, one per time step.
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        # Feed each layer's hidden output into the next layer up.
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # Stack the per-step outputs along axis 0 before the shared projection.
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_classes)

    # Unlike other LSTM examples, the label is not transposed; it is
    # reshaped and cast to int32 before being passed to WarpCTC.
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label,
                        label_length=num_label, input_length=seq_len)
    return sm

166 changes: 166 additions & 0 deletions example/warpctc/lstm_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys, random
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

from lstm import lstm_unroll

from io import BytesIO
from captcha.image import ImageCaptcha
import cv2, random

class SimpleBatch(object):
    """Plain container for one batch of named data and label arrays."""

    def __init__(self, data_names, data, label_names, label):
        self.data_names = data_names
        self.data = data
        self.label_names = label_names
        self.label = label

        self.index = None  # TODO: what is index?
        self.pad = 0

    @property
    def provide_data(self):
        """Return ``(name, shape)`` pairs for the data arrays."""
        described = []
        for name, array in zip(self.data_names, self.data):
            described.append((name, array.shape))
        return described

    @property
    def provide_label(self):
        """Return ``(name, shape)`` pairs for the label arrays."""
        described = []
        for name, array in zip(self.label_names, self.label):
            described.append((name, array.shape))
        return described

def gen_rand():
    """Return a random integer in [0, 9999] as a zero-padded 4-character string."""
    return "%04d" % random.randint(0, 9999)

def get_label(buf):
    """Convert the first four digit characters of ``buf`` into a label array.

    Each digit ``d`` is stored as ``d + 1``. The result is a float array of
    length 4.
    """
    return np.array([1 + int(buf[pos]) for pos in range(4)], dtype=np.float64)

class OCRIter(mx.io.DataIter):
    """Data iterator that generates random captcha batches on the fly.

    Every pass over the iterator yields ``count`` batches of ``batch_size``
    freshly rendered 4-digit captcha images, each flattened to a vector of
    80 * 30 = 2400 values scaled by 1/255.

    Parameters
    ----------
    count : int
        Number of batches produced per pass.
    batch_size : int
        Number of samples in each batch.
    num_label : int
        Stored on the instance. NOTE(review): the label shape below is fixed
        at 4 (matching ``get_label``) and does not use this value.
    init_states : list of (name, shape)
        Names and shapes of the LSTM initial-state inputs; a zero array of
        each shape is appended to the data of every batch.
    """

    def __init__(self, count, batch_size, num_label, init_states):
        super(OCRIter, self).__init__()
        self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        # 2400 = 80 * 30, the flattened size of one resized captcha image.
        self.provide_data = [('data', (batch_size, 2400))] + init_states
        self.provide_label = [('label', (self.batch_size, 4))]

    def __iter__(self):
        # Parenthesised single-argument form prints the same text on
        # Python 2 and Python 3.
        print('iter')
        # The name lists are identical for every batch, so build them once.
        init_state_names = [x[0] for x in self.init_states]
        data_names = ['data'] + init_state_names
        label_names = ['label']
        for k in range(self.count):
            data = []
            label = []
            for i in range(self.batch_size):
                num = gen_rand()
                img = self.captcha.generate(num)
                # np.frombuffer replaces the deprecated binary-mode
                # np.fromstring; imdecode only reads from the buffer.
                img = np.frombuffer(img.getvalue(), dtype='uint8')
                img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
                img = cv2.resize(img, (80, 30))
                # Transpose so that each image column becomes one
                # contiguous 30-value slice of the flattened vector.
                img = img.transpose(1, 0)
                img = img.reshape((80 * 30))
                img = np.multiply(img, 1/255.0)
                data.append(img)
                label.append(get_label(num))

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]

            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        # Nothing to rewind: every pass generates new random samples.
        pass

# Number of samples per batch; Accuracy() also uses it to index predictions.
BATCH_SIZE = 32
# Number of unrolled LSTM time steps; matches the 80-pixel width the
# captcha images are resized to.
SEQ_LENGTH = 80

def ctc_label(p):
    """Collapse a per-step prediction list into a label sequence.

    A value is kept only if it is non-zero and differs from the value at the
    previous step (the first step is compared against 0).
    """
    shifted = [0] + p
    return [cur for prev, cur in zip(shifted, p) if cur != 0 and cur != prev]

def Accuracy(label, pred):
    """Return the fraction of samples whose decoded prediction equals the label.

    ``pred`` is indexed as ``pred[step * BATCH_SIZE + sample]``. For each
    sample, the per-step argmax over ``SEQ_LENGTH`` steps is collapsed with
    ``ctc_label`` and compared element-wise against ``label[sample]``.
    """
    hit = 0.
    total = 0.
    for sample in range(BATCH_SIZE):
        truth = label[sample]
        best_per_step = [np.argmax(pred[step * BATCH_SIZE + sample])
                         for step in range(SEQ_LENGTH)]
        decoded = ctc_label(best_per_step)
        # A hit requires equal length and an exact match at every position.
        if len(decoded) == len(truth):
            if all(decoded[pos] == int(truth[pos]) for pos in range(len(decoded))):
                hit += 1.0
        total += 1.0
    return hit / total

if __name__ == '__main__':
    # Script entry point: train the LSTM + WarpCTC captcha model and save it.
    import logging

    num_hidden = 100
    num_lstm_layer = 2

    num_epoch = 10
    learning_rate = 0.001
    momentum = 0.9
    num_label = 4

    # NOTE(review): this selects GPU index 1; machines with a single GPU
    # presumably need gpu(0) -- confirm for the target host.
    contexts = [mx.context.gpu(1)]

    def sym_gen(seq_len):
        """Build the unrolled network for the given sequence length."""
        return lstm_unroll(num_lstm_layer, seq_len,
                           num_hidden=num_hidden,
                           num_label=num_label)

    # Every LSTM layer needs an initial cell state and hidden state input.
    init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    data_train = OCRIter(10000, BATCH_SIZE, num_label, init_states)
    data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states)

    symbol = sym_gen(SEQ_LENGTH)

    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    # Parenthesised single-argument form prints the same text on
    # Python 2 and Python 3.
    print('begin fit')

    model.fit(X=data_train, eval_data=data_val,
              eval_metric=mx.metric.np(Accuracy),
              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)

    model.save("ocr")
Loading

0 comments on commit 8fd4d16

Please sign in to comment.