Skip to content

Commit

Permalink
Merge pull request clovaai#209 from ku21fan/master
Browse files Browse the repository at this point in the history
add Baidu warpctc option to reproduce CTC results of our paper.
  • Loading branch information
gwkrsrch authored Aug 3, 2020
2 parents 32fd91c + f0a295e commit 3c2c89a
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Based on this framework, we recorded the 1st place of [ICDAR2013 focused scene t
The difference between our paper and ICDAR challenge is summarized [here](https://github.com/clovaai/deep-text-recognition-benchmark/issues/13).

## Updates
**Aug 3, 2020**: added [guideline to use Baidu warpctc](https://github.com/clovaai/deep-text-recognition-benchmark/pull/209) which reproduces CTC results of our paper. <br>
**Dec 27, 2019**: added [FLOPS](https://github.com/clovaai/deep-text-recognition-benchmark/issues/125) in our paper, and minor updates such as log_dataset.txt and [ICDAR2019-NormalizedED](https://github.com/clovaai/deep-text-recognition-benchmark/blob/86451088248e0490ff8b5f74d33f7d014f6c249a/test.py#L139-L165). <br>
**Oct 22, 2019**: added [confidence score](https://github.com/clovaai/deep-text-recognition-benchmark/issues/82), and arranged the output form of training logs. <br>
**Jul 31, 2019**: The paper is accepted at International Conference on Computer Vision (ICCV), Seoul 2019, as an oral talk. <br>
Expand Down
16 changes: 14 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

# # To easily compute the total accuracy of our paper.
# eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
# 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']

if calculate_infer_time:
evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image.
else:
Expand Down Expand Up @@ -100,10 +104,17 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# Calculate evaluation loss for CTC deocder.
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
# permute 'preds' to use CTCloss format
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
if opt.baiduCTC:
cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
else:
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

# Select max probabilty (greedy decoding) then decode index to character
_, preds_index = preds.max(2)
if opt.baiduCTC:
_, preds_index = preds.max(2)
preds_index = preds_index.view(-1)
else:
_, preds_index = preds.max(2)
preds_str = converter.decode(preds_index.data, preds_size.data)

else:
Expand Down Expand Up @@ -246,6 +257,7 @@ def test(opt):
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
""" Model Architecture """
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
Expand Down
23 changes: 18 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.utils.data
import numpy as np

from utils import CTCLabelConverter, AttnLabelConverter, Averager
from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
Expand Down Expand Up @@ -45,7 +45,10 @@ def train(opt):

""" model configuration """
if 'CTC' in opt.Prediction:
converter = CTCLabelConverter(opt.character)
if opt.baiduCTC:
converter = CTCLabelConverterForBaiduWarpctc(opt.character)
else:
converter = CTCLabelConverter(opt.character)
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
Expand Down Expand Up @@ -86,7 +89,12 @@ def train(opt):

""" setup loss """
if 'CTC' in opt.Prediction:
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
if opt.baiduCTC:
# need to install warpctc. see our guideline.
from warpctc_pytorch import CTCLoss
criterion = CTCLoss()
else:
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
# loss averager
Expand Down Expand Up @@ -144,8 +152,12 @@ def train(opt):
if 'CTC' in opt.Prediction:
preds = model(image, text)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.log_softmax(2).permute(1, 0, 2)
cost = criterion(preds, text, preds_size, length)
if opt.baiduCTC:
preds = preds.permute(1, 0, 2) # to use CTCLoss format
cost = criterion(preds, text, preds_size, length) / batch_size
else:
preds = preds.log_softmax(2).permute(1, 0, 2)
cost = criterion(preds, text, preds_size, length)

else:
preds = model(image, text[:, :-1]) # align with Attention.forward
Expand Down Expand Up @@ -232,6 +244,7 @@ def train(opt):
parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
""" Data processing """
parser.add_argument('--select_data', type=str, default='MJ-ST',
help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
Expand Down
47 changes: 47 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,53 @@ def decode(self, text_index, length):
return texts


class CTCLabelConverterForBaiduWarpctc(object):
""" Convert between text-label and text-index for baidu warpctc """

def __init__(self, character):
# character (str): set of the possible characters.
dict_character = list(character)

self.dict = {}
for i, char in enumerate(dict_character):
# NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
self.dict[char] = i + 1

self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)

def encode(self, text, batch_max_length=25):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
length = [len(s) for s in text]
text = ''.join(text)
text = [self.dict[char] for char in text]

return (torch.IntTensor(text), torch.IntTensor(length))

def decode(self, text_index, length):
""" convert text-index into text-label. """
texts = []
index = 0
for l in length:
t = text_index[index:index + l]

char_list = []
for i in range(l):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
char_list.append(self.character[t[i]])
text = ''.join(char_list)

texts.append(text)
index += l
return texts


class AttnLabelConverter(object):
""" Convert between text-label and text-index """

Expand Down

0 comments on commit 3c2c89a

Please sign in to comment.