-
Notifications
You must be signed in to change notification settings - Fork 0
/
detect_text.py
62 lines (44 loc) · 1.5 KB
/
detect_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os

import matplotlib.pyplot as plt
import numpy as np
import pytesseract
import torch
from PIL import Image
from ultralytics import YOLO
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor
# Maps the class index predicted by the YOLO model to the output field name
# used as a key in get_text()'s result.
names_index = {0: 'id', 1: 'name', 2: 'birth'}

# NOTE(review): not referenced in the visible part of this file; kept because
# other modules may import them.
imgWidth = 640
imgHeight = 480

# YOLO
sources_model = './model/finalTextDetection.pt'
model = YOLO(sources_model)

# VietOcr
config = Cfg.load_config_from_name('vgg_transformer')
config['weights'] = './model/transformerocr.pth'
# Use the first GPU when one is available (the previous fixed setting), but
# fall back to the CPU so importing this module does not fail on machines
# without CUDA.
config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config['cnn']['pretrained'] = False
config['predictor']['beamsearch'] = False
detector = Predictor(config)
def ocr(crop_img):
    """Return whatever the module-level VietOCR ``detector`` predicts for *crop_img*."""
    return detector.predict(crop_img)
def get_text(img):
    """Detect the text fields in an image and run OCR on each of them.

    Args:
        img: A PIL image, or a path (``str`` / ``os.PathLike``) to an image
            file, which is opened with ``PIL.Image.open``.

    Returns:
        dict: The keys ``'id'``, ``'name'`` and ``'birth'`` mapped to the OCR
        result for that field; a field without any detection stays ``''``.
    """
    # Cropping below needs a PIL image, so a path (as used by the
    # commented-out example at the bottom of this file) is opened first. The
    # same image object is then used for detection and cropping, which keeps
    # the box coordinates and the cropped pixels consistent.
    if isinstance(img, (str, os.PathLike)):
        img = Image.open(img)

    results = model.predict(source=img)

    res = {'id': '', 'name': '', 'birth': ''}

    # Crop rectangle chosen for each detected field.
    # NOTE(review): when a field is detected more than once, the last
    # detection wins, exactly as before; confirm whether the most confident
    # box or a concatenation of all boxes is actually wanted.
    crop_boxes = {}
    for box in results[0].boxes:
        field = names_index.get(int(box.cls[0]))
        if field not in res:
            # Class index without an output slot: skip it instead of raising
            # KeyError.
            continue
        # Box corners as plain ints (xyxy order), in the tuple form crop() expects.
        crop_boxes[field] = tuple(box.xyxy[0].cpu().numpy().astype(int).tolist())

    # One OCR call per field rather than one per box: earlier results for the
    # same field were overwritten anyway.
    for field, crop_box in crop_boxes.items():
        res[field] = ocr(img.crop(crop_box))

    return res
# if __name__ == '__main__':
# path_image = '/home/anhalu/anhalu-data/AN.LAB/id_card_ocr/Data/Data_TextDetection/images/train/10.jpg'
# get_text(path_image)