diff --git a/dataset.py b/dataset.py
index 8e82805496..c63e386b66 100755
--- a/dataset.py
+++ b/dataset.py
@@ -2,6 +2,7 @@
 import sys
 import re
 import six
+import math
 import lmdb
 import torch
 
@@ -24,7 +25,7 @@ def __init__(self, opt):
        print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
        assert len(opt.select_data) == len(opt.batch_ratio)
 
-        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
+        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
         self.data_loader_list = []
         self.dataloader_iter_list = []
         batch_size_list = []
@@ -191,19 +192,60 @@ def __call__(self, img):
         return img
 
 
+class NormalizePAD(object):
+
+    def __init__(self, max_size, PAD_type='right'):
+        self.toTensor = transforms.ToTensor()
+        self.max_size = max_size
+        self.max_width_half = math.floor(max_size[2] / 2)
+        self.PAD_type = PAD_type
+
+    def __call__(self, img):
+        img = self.toTensor(img)
+        img.sub_(0.5).div_(0.5)
+        c, h, w = img.size()
+        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
+        Pad_img[:, :, :w] = img  # right pad
+        if self.max_size[2] != w:  # add border Pad
+            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
+
+        return Pad_img
+
+
 class AlignCollate(object):
 
-    def __init__(self, imgH=32, imgW=100):
+    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
         self.imgH = imgH
         self.imgW = imgW
+        self.keep_ratio_with_pad = keep_ratio_with_pad
 
     def __call__(self, batch):
         batch = filter(lambda x: x is not None, batch)
         images, labels = zip(*batch)
 
-        transform = ResizeNormalize((self.imgW, self.imgH))
-        image_tensors = [transform(image) for image in images]
-        image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
+        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
+            resized_max_w = self.imgW
+            transform = NormalizePAD((1, self.imgH, resized_max_w))
+
+            resized_images = []
+            for image in images:
+                w, h = image.size
+                ratio = w / float(h)
+                if math.ceil(self.imgH * ratio) > self.imgW:
+                    resized_w = self.imgW
+                else:
+                    resized_w = math.ceil(self.imgH * ratio)
+
+                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
+                resized_images.append(transform(resized_image))
+                # resized_image.save('./image_test/%d_test.jpg' % w)
+
+            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
+
+        else:
+            transform = ResizeNormalize((self.imgW, self.imgH))
+            image_tensors = [transform(image) for image in images]
+            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
 
         return image_tensors, labels
 
diff --git a/test.py b/test.py
index d09bd33971..7832be7d55 100755
--- a/test.py
+++ b/test.py
@@ -32,7 +32,7 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
     print('-' * 80)
     for eval_data in eval_data_list:
         eval_data_path = os.path.join(opt.eval_data, eval_data)
-        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
+        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
         eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=evaluation_batch_size,
@@ -172,7 +172,7 @@ def test(opt):
     if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
         benchmark_all_eval(model, criterion, converter, opt)
     else:
-        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
+        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
         eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=opt.batch_size,
@@ -201,6 +201,7 @@ def test(opt):
     parser.add_argument('--rgb', action='store_true', help='use rgb input')
     parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
     parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
+    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
     """ Model Architecture """
     parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
     parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
diff --git a/train.py b/train.py
index c5717a351f..e8bf30bc7a 100755
--- a/train.py
+++ b/train.py
@@ -24,7 +24,7 @@ def train(opt):
     opt.batch_ratio = opt.batch_ratio.split('-')
     train_dataset = Batch_Balanced_Dataset(opt)
 
-    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
+    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
     valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset, batch_size=opt.batch_size,
@@ -224,6 +224,7 @@ def train(opt):
     parser.add_argument('--rgb', action='store_true', help='use rgb input')
     parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
+    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
     """ Model Architecture """
     parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
     parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
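
A minimal sketch of what the new --PAD path does, for reviewers; it is not part of the patch. It assumes the patched dataset.py is on the import path and that torch, torchvision, and Pillow are installed; the toy images and labels below are invented for illustration.

from PIL import Image

from dataset import AlignCollate

# Two synthetic grayscale word crops with different aspect ratios.
batch = [
    (Image.new('L', (60, 20)), 'foo'),      # ratio 3.0 -> resized to 96 px wide, padded to 100
    (Image.new('L', (300, 40)), 'barbaz'),  # ratio 7.5 -> capped at the full 100 px width
]

collate = AlignCollate(imgH=32, imgW=100, keep_ratio_with_pad=True)
image_tensors, labels = collate(batch)

# Each crop keeps its aspect ratio at height 32; NormalizePAD then
# right-pads it, replicating the last pixel column, up to imgW.
print(image_tensors.shape)  # torch.Size([2, 1, 32, 100])
print(labels)               # ('foo', 'barbaz')

With keep_ratio_with_pad=False (the default, preserving the old behavior), every crop is instead stretched to imgH x imgW by ResizeNormalize, so narrow crops get distorted rather than padded.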