Skip to content

Commit

Permalink
box可视化
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaofengShi committed Apr 23, 2018
1 parent 1ffed36 commit e24eb9c
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 145 deletions.
66 changes: 34 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# 本文基于tensorflow、keras/pytorch实现对自然场景的文字检测及端到端的OCR中文文字识别
# `本文基于tensorflow、keras/pytorch实现对自然场景的文字检测及端到端的OCR中文文字识别`

# 参考github仓库
[TOTAL](https://github.com/chineseocr/chinese-ocr/tree/chinese-ocr-python-3.6)
[TOTAL](https://github.com/chineseocr/chinese-ocr/tree/chinese-ocr-python-3.6)-挂掉了

[CRNN—pytorch](https://github.com/meijieru/crnn.pytorch.git)


# 实现功能

- [x] 文字方向检测 0、90、180、270度检测
- [x] 文字检测 后期将切换到keras版本文本检测 实现keras端到端的文本检测及识别
- [x] 不定长OCR识别
- [x] 增加python3.6 支持
- 文字方向检测 0、90、180、270度检测
- 文字检测 后期将切换到keras版本文本检测 实现keras端到端的文本检测及识别
- 不定长OCR识别


## 环境部署
Expand All @@ -24,33 +24,39 @@ sh setup-python3.sh
```

# 模型训练
* 一共分为3个网络
* **1. 文本方向检测网络-Classify(vgg16)**
* **2. 文本区域检测网络-CTPN(CNN+RNN)**
* **3. EndToEnd文本识别网络-CRNN(CNN+GRU/LSTM+CTC)**

## 训练keras版本的crnn

``` Bash
cd train & sh train-keras.sh
# 文字方向检测
```

## 训练pytorch版本的crnn

``` Bash
cd train & sh train-pytorch.sh
基于图像分类,在VGG16模型的基础上,训练0、90、180、270度检测的分类模型.
详细代码参考angle/predict.py文件,训练图片8000张,准确率88.23%
```
# 文字方向检测
基于图像分类,在VGG16模型的基础上,训练0、90、180、270度检测的分类模型,详细代码参考angle/predict.py文件,训练图片8000张,准确率88.23%。
模型地址[百度云](https://pan.baidu.com/s/1pM2ha5P)下载
模型地址[BaiduCloud](https://pan.baidu.com/s/1zquQNdO0MUsLMsuwxbgPYg)

# 文字检测
# 文字区域检测CTPN
支持CPU、GPU环境,一键部署,
[文本检测训练参考](https://github.com/eragonruan/text-detection-ctpn)


# OCR 端到端识别:GRU+CTC
# OCR 端到端识别:CRNN
## ocr识别采用GRU+CTC端到到识别技术,实现不分隔识别不定长文字
提供keras 与pytorch版本的训练代码,在理解keras的基础上,可以切换到pytorch版本,此版本更稳定
- 此外还添加了tensorflow版本的资源仓库:[TF:LSTM-CTC_loss](https://github.com/ilovin/lstm_ctc_ocr)
- 此外还添加了tensorflow版本的资源仓库:[TF:LSTM-CTC_loss](https://github.com/ilovin/lstm_ctc_ocr)
## 训练keras版本的crnn

``` Bash
cd train & sh train-keras.sh
```

## 训练pytorch版本的crnn

``` Bash
cd train & sh train-pytorch.sh
```

# 识别结果展示
## 文字检测及OCR识别结果
<div>
Expand All @@ -66,16 +72,12 @@ cd train & sh train-pytorch.sh
</div>

## 参考
```
1.crnn
https://github.com/meijieru/crnn.pytorch.git
2.keras-crnn 版本实现参考 https://www.zhihu.com/question/59645822
3.tensorflow-crnn
https://github.com/ilovin/lstm_ctc_ocr

3.ctpn
https://github.com/eragonruan/text-detection-ctpn
https://github.com/tianzhi0549/CTPN
```
- [pytorch 实现crnn](https://github.com/meijieru/crnn.pytorch.git)
- [keras-crnn 版本实现参考](https://www.zhihu.com/question/59645822)
- [tensorflow-crnn](https://github.com/ilovin/lstm_ctc_ocr)
- [tensorflow-ctpn](https://github.com/eragonruan/text-detection-ctpn
)
- [CAFFE-CTPN](https://github.com/tianzhi0549/CTPN
)

4 changes: 2 additions & 2 deletions angle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python2
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
图像文字方向检测
@author: lywen
@author: xiaofeng
"""
15 changes: 14 additions & 1 deletion angle/predict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# _Author_: xiaofeng
# Date: 2018-04-22 18:13:46
# Last Modified by: xiaofeng
# Last Modified time: 2018-04-22 18:13:46
'''
根据给定的图形,分析文字的朝向
'''
# from keras.models import load_model
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -42,10 +50,15 @@ def predict(path=None, img=None):
elif img is not None:
im = Image.fromarray(img).convert('RGB')
w, h = im.size
# 对图像进行剪裁
# 左上角(int(0.1 * w), int(0.1 * h))
# 右下角(w - int(0.1 * w), h - int(0.1 * h))
xmin, ymin, xmax, ymax = int(0.1 * w), int(
0.1 * h), w - int(0.1 * w), h - int(0.1 * h)
im = im.crop((xmin, ymin, xmax, ymax)) ##剪切图片边缘,清楚边缘噪声
im = im.crop((xmin, ymin, xmax, ymax)) ##剪切图片边缘,清除边缘噪声
# 对图片进行剪裁之后进行resize成(224,224)
im = im.resize((224, 224))
# 将图像转化成数组形式
img = np.array(im)
img = preprocess_input(img.astype(np.float32))
pred = model.predict(np.array([img]))
Expand Down
21 changes: 16 additions & 5 deletions ctpn/ctpn/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys, os
import sys
import os

import tensorflow as tf

Expand All @@ -13,38 +14,48 @@
# from ..lib.networks.factory import get_network
# from ..lib.fast_rcnn.config import cfg
# from..lib.fast_rcnn.test import test_ctpn
'''
load network
输入的名称为'Net_model'
'VGGnet_test'--test
'VGGnet_train'-train
'''


def load_tf_model():
cfg.TEST.HAS_RPN = True # Use RPN for proposals
# init session
config = tf.ConfigProto(allow_soft_placement=True)
# load network
net = get_network("VGGnet_test")
# load model
saver = tf.train.Saver()
# sess = tf.Session(config=config)
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state('/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/checkpoints/')
ckpt = tf.train.get_checkpoint_state(
'/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/checkpoints/')
reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print("Tensor_name is : ", key)
# print(reader.get_tensor(key))
saver.restore(sess, ckpt.model_checkpoint_path)
print("load vggnet done")
return sess, saver, net


##init model
# init model
sess, saver, net = load_tf_model()


# 进行文本识别
def ctpn(img):
"""
text box detect
"""
scale, max_scale = Config.SCALE, Config.MAX_SCALE
# 对图像进行resize,输出的图像长宽
print('original_size',img.shape)
img, f = resize_im(img, scale=scale, max_scale=max_scale)
print('resize',img.shape,f)
scores, boxes = test_ctpn(sess, net, img)
return scores, boxes, img
11 changes: 10 additions & 1 deletion ctpn/ctpn/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ def prepare_img(im, mean):
return im_data


def draw_boxes(im, bboxes, is_display=True, color=None, caption="Image", wait=True):
def draw_boxes(im,
bboxes,
is_display=True,
color=None,
caption="Image",
wait=True):
"""
boxes: bounding boxes
"""
Expand Down Expand Up @@ -70,6 +75,9 @@ def draw_boxes(im, bboxes, is_display=True, color=None, caption="Image", wait=Tr
text_recs[index, 7] = y4
index = index + 1
# cv2.rectangle(im, tuple(box[:2]), tuple(box[2:4]), c,2)
# cv2.waitKey(0)
# cv2.imshow('kk', im)
cv2.imwrite('/Users/xiaofeng/Code/Github/Chinese-OCR/test/lllll.png',im)

return text_recs, im

Expand All @@ -96,6 +104,7 @@ def normalize(data):


def resize_im(im, scale, max_scale=None):
# 按照scale和图片的长宽的最小值的比值作为输入模型的图片的尺寸
f = float(scale) / min(im.shape[0], im.shape[1])
if max_scale != None and f * max(im.shape[0], im.shape[1]) > max_scale:
f = float(max_scale) / max(im.shape[0], im.shape[1])
Expand Down
8 changes: 7 additions & 1 deletion ctpn/text_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from .ctpn.detectors import TextDetector
from .ctpn.model import ctpn
from .ctpn.other import draw_boxes
'''
进行文区别于识别-网络结构为cnn+rnn
'''


def text_detect(img):
# ctpn网络测到
scores, boxes, img = ctpn(img)
textdetector = TextDetector()
boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
text_recs, tmp = draw_boxes(img, boxes, caption='im_name', wait=True, is_display=False)
# text_recs, tmp = draw_boxes(img, boxes, caption='im_name', wait=True, is_display=False)
text_recs, tmp = draw_boxes(
img, boxes, caption='im_name', wait=True, is_display=True)
return text_recs, tmp, img
5 changes: 4 additions & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
paths = glob('./test/*.*')

if __name__ == '__main__':
im = Image.open("./test/test2.png")
im = Image.open("./test/ttttt.png")
img = np.array(im.convert('RGB'))
t = time.time()
'''
result,img,angel分别对应-识别结果,图像的数组,文字旋转角度
'''
result, img, angle = model.model(
img, model='keras', adjust=True, detectAngle=True)
print("It takes time:{}s".format(time.time() - t))
Expand Down
8 changes: 6 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ def model(img, model='keras', adjust=False, detectAngle=False):
"""
angle = 0
if detectAngle:

# 进行文字旋转方向检测,分为[0, 90, 180, 270]四种情况
angle = angle_detect(img=np.copy(img)) ##文字朝向检测
print('The angel of this character is:', angle)
im = Image.fromarray(img)
print('Rotate the array of this img!')
if angle == 90:
im = im.transpose(Image.ROTATE_90)
elif angle == 180:
Expand All @@ -109,8 +111,10 @@ def model(img, model='keras', adjust=False, detectAngle=False):
im = im.transpose(Image.ROTATE_270)
img = np.array(im)
# 进行图像中的文字区域的识别
text_recs, tmp, img = text_detect(img)
text_recs, tmp, img=text_detect(img)
# 识别区域排列
text_recs = sort_box(text_recs)
#
result = crnnRec(img, text_recs, model, adjust=adjust)
return result, tmp, angle

Expand Down
100 changes: 0 additions & 100 deletions test.txt

This file was deleted.

Binary file added test/lllll.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/ttttt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e24eb9c

Please sign in to comment.