-
Notifications
You must be signed in to change notification settings - Fork 123
/
ctpn_predict.py
90 lines (74 loc) · 2.64 KB
/
ctpn_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#-*- coding:utf-8 -*-
#'''
# Created on 18-12-11 10:03 AM
#
# @Author: Greg Gao(laygin)
#'''
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from ctpn_model import CTPN_Model
from ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox,nms, TextProposalConnectorOriented
from ctpn_utils import resize
import config
# --- Inference settings ------------------------------------------------------
# Threshold compared against the class-1 probability of every anchor; only
# anchors scoring above it are kept as candidates.
prob_thresh = 0.7
# Width handed to resize() for the input image.
width = 600

# Placeholders: point these at a real test image and a real checkpoint file.
img_path = 'path to test image'
weights = os.path.join(config.checkpoints_dir, 'trained weights file.pth.tar')

# Use the first CUDA device when torch reports one, otherwise the CPU.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to '' before torch is imported,
# which presumably hides every GPU, so this is expected to select the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# --- Model -------------------------------------------------------------------
# Build the network, restore the trained parameters from the checkpoint's
# 'model_state_dict' entry, move it to the chosen device and switch to
# evaluation mode.
model = CTPN_Model()
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
del checkpoint  # only needed while restoring the parameters
model.to(device)
model.eval()
def dis(image, window_name='image', delay=0):
    """Show ``image`` in an OpenCV window and wait for a key press.

    Args:
        image: Image array passed unchanged to ``cv2.imshow``.
        window_name: Title of the window. Defaults to ``'image'``.
        delay: Value passed to ``cv2.waitKey``. With the default ``0`` the
            call blocks until a key is pressed; a positive value waits at
            most that many milliseconds.
    """
    cv2.imshow(window_name, image)
    cv2.waitKey(delay)
    # Close every HighGUI window so none is left open afterwards.
    cv2.destroyAllWindows()
# ---- Load and preprocess the test image -------------------------------------
image = cv2.imread(img_path)
if image is None:
    # cv2.imread reports failure by returning None rather than raising; stop
    # here with a clear message instead of failing later in resize().
    raise OSError('failed to read image: {!r}'.format(img_path))

image = resize(image, width=width)
# Keep a copy for drawing, because `image` is turned into a tensor below.
image_c = image.copy()
h, w = image.shape[:2]

# Subtract the configured mean, move the last axis to the front and add a
# leading axis of size 1 before handing the data to the model.
image = image.astype(np.float32) - config.IMAGE_MEAN
image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()

# ---- Forward pass -----------------------------------------------------------
# Only the model call needs no_grad; everything afterwards receives NumPy
# arrays.
with torch.no_grad():
    image = image.to(device)
    cls, regr = model(image)
    cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
    regr = regr.cpu().numpy()

# ---- Decode anchors into boxes ----------------------------------------------
# NOTE(review): 16 is presumably the stride of the model's feature map with
# respect to the input image -- confirm against ctpn_utils.gen_anchor.
anchor = gen_anchor((h // 16, w // 16), 16)
bbox = bbox_transfor_inv(anchor, regr)
bbox = clip_box(bbox, [h, w])

# Keep only the anchors whose class-1 probability exceeds the threshold.
fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
select_anchor = bbox[fg, :].astype(np.int32)
select_score = cls_prob[0, fg, 1]

# NOTE(review): presumably drops boxes that are too small (16 looks like a
# minimum size) -- confirm against ctpn_utils.filter_bbox.
keep_index = filter_bbox(select_anchor, 16)
select_anchor = select_anchor[keep_index]
select_score = select_score[keep_index]

# ---- NMS --------------------------------------------------------------------
# nms() receives the boxes with the score appended as an extra last column.
select_score = np.reshape(select_score, (select_score.shape[0], 1))
nmsbox = np.hstack((select_anchor, select_score))
keep = nms(nmsbox, 0.3)
select_anchor = select_anchor[keep]
select_score = select_score[keep]

# ---- Text lines -------------------------------------------------------------
textConn = TextProposalConnectorOriented()
text = textConn.get_text_lines(select_anchor, select_score, [h, w])
print(text)

# ---- Draw the results -------------------------------------------------------
# Each row holds four (x, y) points in its first eight values and a score in
# its last value. The edges are drawn in this fixed order of point pairs.
edges = ((0, 1), (0, 2), (3, 1), (2, 3))
for line in text:
    label = str(round(line[-1] * 100, 2)) + '%'
    pts = [int(v) for v in line[:8]]
    corners = [(pts[k], pts[k + 1]) for k in range(0, 8, 2)]
    for start, end in edges:
        cv2.line(image_c, corners[start], corners[end], (0, 0, 255), 2)
    # Score label, offset slightly from the first point.
    cv2.putText(image_c, label, (pts[0] + 13, pts[1] + 13),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 0, 0),
                2,
                cv2.LINE_AA)

dis(image_c)