forked from FangShancheng/ABINet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
109 lines (97 loc) · 4.16 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import logging
import os
import glob
import tqdm
import torch
import PIL
import cv2
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
from utils import Config, Logger, CharsetMapper
def get_model(config):
    """Instantiate the model class named by ``config.model_name``.

    Args:
        config: Object exposing ``model_name``, a dotted path of the form
            ``'package.module.ClassName'``. The same object is passed as the
            only argument to the class constructor.

    Returns:
        The constructed model after calling its ``eval()`` method.

    Raises:
        ValueError: If ``config.model_name`` has no module part or no class
            part (e.g. ``'ClassName'`` or ``'module.'``).
    """
    import importlib

    # Everything before the last dot is the module, the rest is the class.
    module_name, _, class_name = config.model_name.rpartition('.')
    if not module_name or not class_name:
        # Without this check importlib fails with an opaque
        # "Empty module name" error that does not mention the config value.
        raise ValueError(
            "config.model_name must be a dotted 'module.ClassName' path, "
            f"got {config.model_name!r}"
        )
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model.eval()
def preprocess(img, width, height):
    """Resize an image and turn it into a normalized single-item batch.

    Args:
        img: Image convertible with ``np.array`` (``main`` passes an RGB PIL
            image).
        width: Target width handed to ``cv2.resize``.
        height: Target height handed to ``cv2.resize``.

    Returns:
        A tensor with a leading batch dimension of size 1, normalized per
        channel with the fixed mean/std constants below.
    """
    # Per-channel statistics reshaped to (3, 1, 1) so they broadcast over the
    # two trailing dimensions.
    # NOTE(review): these look like the usual ImageNet mean/std - confirm.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

    resized = cv2.resize(np.array(img), (width, height))
    to_tensor = transforms.ToTensor()
    batch = to_tensor(resized).unsqueeze(0)  # add the batch dimension
    return (batch - channel_mean) / channel_std
def postprocess(output, charset, model_eval):
    """Greedy-decode model output into text, scores and lengths.

    Args:
        output: Either a single result mapping containing ``'logits'``, or a
            tuple/list of such mappings that each also carry a ``'name'``.
        charset: Object providing ``get_text(labels, padding=..., trim=...)``,
            ``null_char`` and ``max_length``.
        model_eval: Name of the result to pick when ``output`` is a sequence.

    Returns:
        A tuple ``(pt_text, pt_scores, pt_lengths)`` of lists with one entry
        per batch item: the decoded string, the per-step maximum softmax
        probability, and the text length plus one (capped at
        ``charset.max_length``).

    Raises:
        ValueError: If ``output`` is a sequence and none of its entries is
            named ``model_eval``.
    """
    def _get_output(last_output, model_eval):
        """Pick the result named ``model_eval`` from a sequence of results."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        selected = None
        for res in last_output:
            if res['name'] == model_eval:
                selected = res  # keep scanning: the last match wins
        if selected is None:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                f'No model output named {model_eval!r} was found')
        return selected

    def _decode(logit):
        """Greedy decode: take the most probable class at every step."""
        out = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in out:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # end at end-token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            # one extra position for the end-token
            pt_lengths.append(min(len(text) + 1, charset.max_length))
        return pt_text, pt_scores, pt_lengths

    output = _get_output(output, model_eval)
    return _decode(output['logits'])
def load(model, file, device=None, strict=True):
    """Load weights from a checkpoint file into ``model``.

    Args:
        model: Object with a ``load_state_dict(state, strict=...)`` method.
        file: Path to the checkpoint file.
        device: ``None`` for CPU, an ``int`` for that CUDA device index, or
            anything else accepted by ``torch.load``'s ``map_location``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object, with the state loaded.

    Raises:
        FileNotFoundError: If ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint file not found: {file}')

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(file, map_location=device)
    # A checkpoint holding exactly 'model' and 'opt' wraps the weights under
    # 'model'; anything else is used as the state dict directly.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
def main():
    """Recognize text in the given image(s) and log one prediction per file.

    Command-line arguments select the config file, the input (a directory or
    a glob pattern), the CUDA device index (negative means CPU), the
    checkpoint to load and which model output to decode.

    Raises:
        FileNotFoundError: If the input resolves to no paths.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule by itself,
    # so import it explicitly here.
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figs/test')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str,
                        default='workdir/train-abinet/best-train-abinet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # NOTE(review): presumably cleared so the sub-model checkpoints are not
    # loaded separately from the full checkpoint - confirm against the models.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    # A directory means "every entry inside it"; anything else is a glob.
    if os.path.isdir(args.input):
        paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
    else:
        paths = glob.glob(os.path.expanduser(args.input))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not paths:
        raise FileNotFoundError("The input path(s) was not found")
    paths = sorted(paths)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, __ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()