Skip to content

Commit

Permalink
add PAD option
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed May 10, 2019
1 parent 1426f66 commit ce837ab
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
52 changes: 47 additions & 5 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import re
import six
import math
import lmdb
import torch

Expand All @@ -24,7 +25,7 @@ def __init__(self, opt):
print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
assert len(opt.select_data) == len(opt.batch_ratio)

_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
self.data_loader_list = []
self.dataloader_iter_list = []
batch_size_list = []
Expand Down Expand Up @@ -191,19 +192,60 @@ def __call__(self, img):
return img


class NormalizePAD(object):

def __init__(self, max_size, PAD_type='right'):
self.toTensor = transforms.ToTensor()
self.max_size = max_size
self.max_width_half = math.floor(max_size[2] / 2)
self.PAD_type = PAD_type

def __call__(self, img):
img = self.toTensor(img)
img.sub_(0.5).div_(0.5)
c, h, w = img.size()
Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
Pad_img[:, :, :w] = img # right pad
if self.max_size[2] != w: # add border Pad
Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

return Pad_img


class AlignCollate(object):

def __init__(self, imgH=32, imgW=100):
def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio_with_pad = keep_ratio_with_pad

def __call__(self, batch):
batch = filter(lambda x: x is not None, batch)
images, labels = zip(*batch)

transform = ResizeNormalize((self.imgW, self.imgH))
image_tensors = [transform(image) for image in images]
image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper
resized_max_w = self.imgW
transform = NormalizePAD((1, self.imgH, resized_max_w))

resized_images = []
for image in images:
w, h = image.size
ratio = w / float(h)
if math.ceil(self.imgH * ratio) > self.imgW:
resized_w = self.imgW
else:
resized_w = math.ceil(self.imgH * ratio)

resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
resized_images.append(transform(resized_image))
# resized_image.save('./image_test/%d_test.jpg' % w)

image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

else:
transform = ResizeNormalize((self.imgW, self.imgH))
image_tensors = [transform(image) for image in images]
image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

return image_tensors, labels

Expand Down
5 changes: 3 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
print('-' * 80)
for eval_data in eval_data_list:
eval_data_path = os.path.join(opt.eval_data, eval_data)
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=evaluation_batch_size,
Expand Down Expand Up @@ -172,7 +172,7 @@ def test(opt):
if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets
benchmark_all_eval(model, criterion, converter, opt)
else:
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=opt.batch_size,
Expand Down Expand Up @@ -201,6 +201,7 @@ def test(opt):
parser.add_argument('--rgb', action='store_true', help='use rgb input')
parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
""" Model Architecture """
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(opt):
opt.batch_ratio = opt.batch_ratio.split('-')
train_dataset = Batch_Balanced_Dataset(opt)

AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
valid_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=opt.batch_size,
Expand Down Expand Up @@ -224,6 +224,7 @@ def train(opt):
parser.add_argument('--rgb', action='store_true', help='use rgb input')
parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
""" Model Architecture """
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
Expand Down

0 comments on commit ce837ab

Please sign in to comment.