Skip to content

Commit f2d4864

Browse files
committed
init commit
0 parents  commit f2d4864

File tree

8 files changed

+683
-0
lines changed

8 files changed

+683
-0
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Convolutional Recurrent Neural Network
2+
======================================
3+
4+
This software implements the Convolutional Recurrent Neural Network (CRNN) in pytorch.
5+
The original software can be found at [crnn](https://github.com/bgshih/crnn).

crnn_main.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
from __future__ import print_function
2+
import argparse
3+
import random
4+
import torch
5+
import torch.backends.cudnn as cudnn
6+
import torch.optim as optim
7+
import torch.utils.data
8+
from torch.autograd import Variable
9+
import numpy as np
10+
from warpctc_pytorch import CTCLoss
11+
import os
12+
import utils
13+
import dataset
14+
15+
import models.crnn as crnn
16+
17+
# ---------------------------------------------------------------------------
# Command-line options for CRNN training.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# data
parser.add_argument('--trainroot', required=True, help='path to dataset')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
# model geometry
parser.add_argument('--imgH', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nh', type=int, default=100, help='size of the lstm hidden state')
# optimization
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
# NOTE(review): the default (1) and the help text (0.00005) disagree — confirm the intended learning rate.
parser.add_argument('--lr', type=float, default=1, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
# hardware
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
# checkpoint resume and label alphabet
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
parser.add_argument('--alphabet', type=str, default='abcdefghijklmnopqrstuvwxyz0123456789')
# NOTE(review): --Diters is never read anywhere in this script — looks copied from a WGAN template.
parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
# logging / validation / checkpoint cadence (all in units of training iterations)
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=500, help='Interval to be displayed')
# optimizer selection; RMSprop is the fallback when neither flag is set
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
# preprocessing / sampling behavior
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)
43+
44+
# Output directory for checkpoints.
if opt.experiment is None:
    opt.experiment = 'samples'
# BUG FIX: the original shelled out with os.system('mkdir {0}'), which
# breaks on paths containing spaces and allows shell injection through
# --experiment. Create the directory directly instead.
if not os.path.exists(opt.experiment):
    os.makedirs(opt.experiment)

# Seed every RNG source so runs are reproducible given the printed seed.
opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Let cuDNN auto-tune convolution algorithms (input sizes are fixed here).
cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
58+
59+
# Training data comes from an LMDB store; validation images are resized to a
# fixed 128x32 at load time.
train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
# BUG FIX: DataLoader forbids combining shuffle=True with a custom sampler
# (they are mutually exclusive); shuffle only when no sampler is supplied.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=(sampler is None), sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH,
                                    keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
    root=opt.valroot, transform=dataset.resizeNormalize((128, 32)))
72+
73+
# Model hyperparameters derived from the parsed options.
ngpu = int(opt.ngpu)
nh = int(opt.nh)
alphabet = opt.alphabet
nclass = len(alphabet) + 1  # +1 for the CTC "blank" label
nc = 1  # single-channel (grayscale) input

# Maps between label strings and integer sequences for CTC.
converter = utils.strLabelConverter(alphabet)
criterion = CTCLoss()
81+
82+
83+
def weights_init(m):
    """Custom weight initialization applied to every CRNN submodule.

    Conv layers get N(0, 0.02) weights; BatchNorm layers get N(1, 0.02)
    weights and zero bias. Any other module type is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
92+
# Build the CRNN, apply custom init, and optionally resume from a checkpoint.
crnn = crnn.CRNN(opt.imgH, nc, nclass, nh, ngpu)
crnn.apply(weights_init)
if opt.crnn != '':
    print('loading pretrained model from %s' % opt.crnn)
    crnn.load_state_dict(torch.load(opt.crnn))
print(crnn)

# Reusable host-side buffers refilled each batch via utils.loadData.
# NOTE(review): allocated with 3 channels although nc = 1 above — presumably
# utils.loadData resizes the buffer to the actual batch shape; confirm.
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda()
    image = image.cuda()
    criterion = criterion.cuda()

# Legacy pre-0.4 PyTorch API: tensors must be wrapped in Variables.
image = Variable(image)
text = Variable(text)
length = Variable(length)
111+
112+
# Running average of the training loss between display intervals.
loss_avg = utils.averager()

# Set up the optimizer selected on the command line (RMSprop by default).
# BUG FIX: the original read opt.lrD, but the parser defines --lr and no
# --lrD, so every branch raised AttributeError. Use opt.lr.
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
122+
123+
124+
def val(net, dataset, criterion, max_iter=100):
    """Evaluate `net` on `dataset` for up to `max_iter` batches.

    Prints a handful of sample predictions plus the mean CTC loss and the
    sequence-level accuracy. Gradients are disabled on the global `crnn`
    parameters; the training loop re-enables them.
    """
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    loss_avg = utils.averager()

    # BUG FIX: never iterate past the end of the loader; the original could
    # call next() on an exhausted iterator for small validation sets.
    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        # BUG FIX: .next() is Python 2 only; the next() builtin works on both.
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        # Greedy decode: argmax over classes, then (T, B) -> flat (B*T,).
        _, preds = preds.max(2)
        preds = preds.squeeze(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1

    # Show sample predictions from the last batch.
    # BUG FIX: honor --n_test_disp (previously declared but unused — the
    # whole final batch was printed).
    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
168+
169+
170+
def trainBatch(net, criterion, optimizer):
    """Run one optimization step on the next training batch.

    Pulls a batch from the global `train_iter`, copies it into the shared
    `image`/`text`/`length` buffers, and returns the per-sample CTC cost
    (a Variable) after backprop and an optimizer step.
    """
    # BUG FIX: .next() is Python 2 only; the next() builtin works on both.
    data = next(train_iter)
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
186+
187+
188+
# Main training loop: one pass over train_loader per epoch, with periodic
# logging, validation, and checkpointing keyed off the iteration counter.
for epoch in range(opt.niter):
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        # Re-enable gradients every iteration because val() turns them off.
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if i % opt.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' % (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()

        if i % opt.valInterval == 0:
            val(crnn, test_dataset, criterion)

        # do checkpointing
        if i % opt.saveInterval == 0:
            torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))

dataset.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#!/usr/bin/python
2+
# encoding: utf-8
3+
4+
import random
5+
import torch
6+
from torch.utils.data import Dataset
7+
from torch.utils.data import sampler
8+
import torchvision.transforms as transforms
9+
import lmdb
10+
import six
11+
import sys
12+
from PIL import Image
13+
import numpy as np
14+
15+
16+
class lmdbDataset(Dataset):
    """Dataset backed by an LMDB store of image/label pairs.

    The database is expected to hold a 'num-samples' count plus
    'image-%09d' (encoded image bytes) and 'label-%09d' (text) entries
    indexed from 1.
    """

    def __init__(self, root=None, transform=None, target_transform=None):
        # Read-only environment; locking/readahead disabled so multiple
        # DataLoader workers can share the same database safely.
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(0)

        with self.env.begin(write=False) as txn:
            # NOTE(review): str keys imply a Python 2 lmdb setup; on
            # Python 3 these keys would need to be bytes — confirm.
            nSamples = int(txn.get('num-samples'))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1  # LMDB keys are 1-based; DataLoader indices are 0-based
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key)

            # Decode the stored bytes as a grayscale PIL image.
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # Best-effort fallback to the following sample.
                # NOTE(review): steps past the valid range if the LAST
                # image is corrupted — confirm acceptable.
                return self[index + 1]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            label = str(txn.get(label_key))

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
67+
68+
69+
class resizeNormalize(object):
    """Resize a PIL image to a fixed (W, H) and normalize it to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        # ToTensor yields [0, 1]; shift/scale in place to [-1, 1].
        tensor.sub_(0.5).div_(0.5)
        return tensor
81+
82+
83+
class randomSequentialSampler(sampler.Sampler):
    """Sampler yielding randomly-anchored contiguous runs of indices.

    Each batch is a consecutive slice [start, start + batch_size) with a
    random start, keeping LMDB reads mostly sequential while still varying
    batch content between epochs.
    """

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batch):
            random_start = random.randint(0, len(self) - self.batch_size)
            # BUG FIX: torch.range is deprecated/removed; torch.arange
            # produces the same 0..batch_size-1 run with integer dtype.
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            # BUG FIX: the original indexed with `i`, which is undefined
            # when n_batch == 0, and its randint upper bound went negative
            # for datasets smaller than batch_size. Anchor the tail run so
            # it always fits within the dataset.
            random_start = random.randint(0, len(self) - tail)
            tail_index = random_start + torch.arange(0, tail)
            index[n_batch * self.batch_size:] = tail_index

        return iter(index)

    def __len__(self):
        return self.num_samples
107+
108+
109+
class alignCollate(object):
    """Collate function that resizes a batch of PIL images to one size.

    With keep_ratio=True the target width follows the widest aspect ratio
    in the batch (never below imgH * min_ratio); otherwise the fixed imgW
    is used. Labels are passed through untouched.
    """

    def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            # Width tracks the widest image's aspect ratio in this batch.
            widest = max(w / float(h) for w, h in (im.size for im in images))
            imgW = int(np.floor(widest * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        tensors = [transform(im).unsqueeze(0) for im in images]
        images = torch.cat(tensors, 0)

        return images, labels

models/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)