add det+rec ckpt prediction pipeline #216

Merged · 1 commit · Apr 28, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -42,4 +42,4 @@ jobs:
pytest tests/ut/*.py
- name: Test with pytest (ST)
run: |
pytest tests/st/test_train_eval_dummy.py
pytest tests/st/test_train_eval_dummy.py
2 changes: 1 addition & 1 deletion README.md
@@ -101,7 +101,7 @@ MX, which is short for [MindX](https://www.hiascend.com/zh/software/mindx-sdk),

MindOCR supports OCR model inference with MX Engine. Please refer to [mx_infer](docs/cn/inference_tutorial_cn.md) for detailed illustrations.

#### 2.2 Inference with MS Lite
#### 2.2 Inference with MindSpore Lite

Coming soon

4 changes: 2 additions & 2 deletions README_CN.md
@@ -92,13 +92,13 @@ MX (short for [MindX](https://www.hiascend.com/zh/software/mindx-sdk)) is a
MindOCR integrates the MX inference engine and supports text detection and recognition tasks. Please refer to [mx_infer](docs/cn/inference_tutorial_cn.md).


#### 2.2 Inference with Lite
#### 2.2 Inference with MindSpore Lite

Coming soon

#### 2.3 Inference with native MindSpore

Coming soon
MindOCR supports chained text detection + text recognition inference from ckpt files trained with MindSpore. Please refer to [here](docs/cn/predict_ckpt_cn.md).

## Model List

40 changes: 39 additions & 1 deletion configs/det/dbnet/db_r50_icdar15.yaml
@@ -113,7 +113,7 @@ train:
num_workers: 8

eval:
ckpt_load_path: 'tmp_det/best.ckpt'
ckpt_load_path: tmp_det/best.ckpt
dataset_sink_mode: False
dataset:
type: DetDataset
@@ -148,3 +148,41 @@ eval:
batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
drop_remainder: False
num_workers: 2

predict:
> **Reviewer (Collaborator):** Because of the separate `PredictDataset` class, there are many duplicates in the config file. It's easy to make a mistake when there is so much repetition.

ckpt_load_path: tmp_det/best.ckpt
dataset_sink_mode: False
dataset:
type: PredictDataset
dataset_root: path/to/dataset_root
data_dir: ic15/det/test/ch4_test_images
# label_file: test.txt
sample_ratio: 1.0
transform_pipeline:
- DecodeImage:
img_mode: RGB
to_float32: False
# - DetLabelEncode:
- GridResize:
factor: 32
# GridResize already sets the evaluation size to [ 736, 1280 ].
# Uncomment ScalePadImage block for other resolutions.
# - ScalePadImage:
# target_size: [ 736, 1280 ] # h, w
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'img_path', 'image', 'raw_img_shape' ] # shape in h, w order
num_columns_to_net: 1 # num inputs for network forward func
# num_keys_of_labels: 2 # num labels

loader:
shuffle: False
batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
drop_remainder: False
num_workers: 2
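
For reference, a minimal sketch (assuming PyYAML is available; not part of this PR) of reading the new `predict` section back from the detection config:

```python
# Hedged sketch: load the detection config and inspect the `predict` block added above.
import yaml

with open("configs/det/dbnet/db_r50_icdar15.yaml", "r") as f:
    cfg = yaml.safe_load(f)

predict_cfg = cfg["predict"]
print(predict_cfg["ckpt_load_path"])         # tmp_det/best.ckpt
print(predict_cfg["dataset"]["type"])        # PredictDataset
print(predict_cfg["loader"]["batch_size"])   # 1
```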

46 changes: 44 additions & 2 deletions configs/rec/crnn/crnn_resnet34.yaml
@@ -1,6 +1,6 @@
system:
mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
distribute: True
distribute: True
amp_level: 'O3'
seed: 42
log_interval: 100
@@ -111,7 +111,7 @@ train:
num_workers: 8

eval:
ckpt_load_path: './tmp_rec/best.ckpt'
ckpt_load_path: ./tmp_rec/best.ckpt
dataset_sink_mode: False
dataset:
type: LMDBDataset
@@ -151,3 +151,45 @@ eval:
drop_remainder: False
max_rowsize: 12
num_workers: 8

predict:
ckpt_load_path: ./tmp_rec/best.ckpt
vis_font_path: tools/utils/simfang.ttf
dataset_sink_mode: False
dataset:
type: PredictDataset
dataset_root: path/to/dataset_root
data_dir: predict_result/crop
# label_files: # not required for PredictDataset
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: BGR
to_float32: False
# - RecCTCLabelEncode:
# max_text_len: *max_text_len
# character_dict_path: *character_dict_path
# use_space_char: *use_space_char
# lower: True
- RecResizeImg: # different from paddle (paddle converts the image from HWC to CHW and rescales to [-1, 1] after resizing)
image_shape: [32, 100] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
padding: False # aspect ratio will be preserved if true.
- NormalizeImage: # different from paddle (paddle wrongly normalizes the BGR image with RGB mean/std from ImageNet for det, and simply rescales to [-1, 1] for rec)
bgr_to_rgb: True
is_hwc: True
mean: [127.0, 127.0, 127.0]
std: [127.0, 127.0, 127.0]
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, plus optional data for debugging/visualization
output_columns: [ 'img_path', 'image', 'raw_img_shape' ]
num_columns_to_net: 1 # num inputs for network forward func

loader:
shuffle: False # TODO: tbc
batch_size: 1
drop_remainder: True
max_rowsize: 12
num_workers: 8
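
The `NormalizeImage` comment above encodes a specific arithmetic: with mean = std = 127, uint8 pixel values map approximately to [-1, 1]. A minimal NumPy sketch of that step (illustrative only, not mindocr's implementation):

```python
import numpy as np

# A 32x100 RGB crop as produced by RecResizeImg (H, W, C), values in [0, 255].
img = np.random.randint(0, 256, size=(32, 100, 3)).astype(np.float32)

mean = np.array([127.0, 127.0, 127.0], dtype=np.float32)
std = np.array([127.0, 127.0, 127.0], dtype=np.float32)

normalized = (img - mean) / std        # elementwise per channel; range is roughly [-1.0, 1.01]
chw = normalized.transpose(2, 0, 1)    # ToCHWImage: HWC -> CHW for the network
print(chw.shape)                       # (3, 32, 100)
```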
90 changes: 90 additions & 0 deletions docs/cn/predict_ckpt_cn.md
@@ -0,0 +1,90 @@
# MindOCR Pipeline Inference

This document describes how to run chained text detection + text recognition inference with ckpt files trained in MindSpore.

## 1. Supported Model Combinations

| Text Detection + Recognition Combination | Dataset | Inference Accuracy |
|------------------------------------------|-------------------------------------------------------------------|--------------------|
| DBNet+CRNN | [ICDAR15](https://rrc.cvc.uab.es/?ch=4&com=downloads)<sup>*</sup> | 55.99% |

> *Inference here uses the Test Set of ICDAR15 Task 4.1.

## 2. Quick Start

### 2.1 Environment

| Environment/Device | Version |
|--------------------|---------|
| MindSpore | >=1.9 |
| Python | >=3.7 |


### 2.2 Parameter Configuration

Parameter configuration consists of two parts: (1) the model yaml configuration files, and (2) the args parameters of the inference script `tools/predict/text/predict_system.py`.

**Note: an args value passed in (2) overrides the corresponding value in the (1) yaml configuration files; otherwise, the default value from the yaml file is used. You can also update the values in the yaml files manually.**

#### (1) yaml configuration files

The detection model and the recognition model each have their own yaml configuration file. Pay close attention to the `predict` section in **both** files; the key parameters are shown below.

```yaml
...
predict:
ckpt_load_path: tmp_det/best.ckpt <--- overridden by args.det_ckpt_path in the detection yaml and args.rec_ckpt_path in the recognition yaml; or update the value manually
dataset_sink_mode: False
dataset:
type: PredictDataset
dataset_root: path/to/dataset_root <--- overridden by args.raw_data_dir in the detection yaml and args.crop_save_dir in the recognition yaml; or update the value manually
data_dir: ic15/det/test/ch4_test_images <--- overridden by args.raw_data_dir in the detection yaml and args.crop_save_dir in the recognition yaml; or update the value manually
sample_ratio: 1.0
transform_pipeline:
...
output_columns: [ 'img_path', 'image', 'raw_img_shape' ]
num_columns_to_net: 1
loader:
shuffle: False
batch_size: 1
...
```

#### (2) args parameter list

| Parameter | Meaning | Default |
|------------------|---------------------------------------------------------------------------------------------------------------------------|------------------------------------------|
| raw_data_dir | Directory of the data to be predicted | - |
| det_ckpt_path | Path to the detection model ckpt file | - |
| rec_ckpt_path | Path to the recognition model ckpt file | - |
| det_config_path | Path to the detection model yaml configuration file | 'configs/det/dbnet/db_r50_icdar15.yaml' |
| rec_config_path | Path to the recognition model yaml configuration file | 'configs/rec/crnn/crnn_resnet34.yaml' |
| crop_save_dir | Directory for saving images cropped after detection in the pipeline, **i.e., the directory the recognition model reads images from** | 'predict_result/crop' |
| result_save_path | Path for saving the pipeline inference results | 'predict_result/ckpt_pred_result.txt' |


### 2.3 Inference

Run the following command to start pipeline inference. **Argument values passed below override the corresponding values in the yaml files.**

```bash
python tools/predict/text/predict_system.py \
--raw_data_dir path/to/raw_data \
--det_ckpt_path path/to/detection_ckpt \
--rec_ckpt_path path/to/recognition_ckpt
```

### 2.4 Accuracy Evaluation

After inference completes, the image name, the detected text boxes (points), and the recognized text (transcription) are saved to args.result_save_path. An example of the result file format:
```text
img_1.jpg [{"transcription": "hello", "points": [600, 150, 715, 157, 714, 177, 599, 170]}, {"transcription": "world", "points": [622, 126, 695, 129, 694, 154, 621, 151]}, ...]
img_2.jpg [{"transcription": "apple", "points": [553, 338, 706, 318, 709, 342, 556, 362]}, ...]
...
```
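
The result file can be read back into Python for downstream processing. Below is a minimal sketch (the helper name is illustrative; it assumes the image name and the JSON list are separated by the first run of whitespace):

```python
import json

def load_pipeline_results(path: str) -> dict:
    """Map each image name to its list of {'transcription', 'points'} records."""
    results = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            img_name, json_str = line.split(maxsplit=1)
            results[img_name] = json.loads(json_str)
    return results

preds = load_pipeline_results("predict_result/ckpt_pred_result.txt")
for record in preds.get("img_1.jpg", []):
    print(record["transcription"], record["points"])
```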

Once the **ground truth file** for the inference images (in the same format as the result file above) and the **result file** are ready, run the following command to evaluate the accuracy of the pipeline inference.
```bash
cd deploy/eval_utils
python eval_pipeline.py --gt_path path/to/gt.txt --pred_path path/to/ckpt_pred_result.txt
```
4 changes: 2 additions & 2 deletions mindocr/data/base_dataset.py
@@ -31,15 +31,15 @@ def __init__(self,
data_dir = [data_dir]
for f in data_dir:
if not os.path.exists(f):
raise ValueError(f"{f} not existed. Please check the yaml file for both train and eval")
raise ValueError(f"data_dir '{f}' does not existed. Please check the yaml file for both train and eval")
self.data_dir = data_dir

if label_file is not None:
if isinstance(label_file, str):
label_file = [label_file]
for f in label_file:
if not os.path.exists(f):
raise ValueError(f"{f} not existed. Please check the yaml file for both train and eval")
raise ValueError(f"label_file '{f}' does not existed. Please check the yaml file for both train and eval")
else:
label_file = []
self.label_file = label_file
3 changes: 2 additions & 1 deletion mindocr/data/builder.py
@@ -7,10 +7,11 @@
from .det_dataset import DetDataset, SynthTextDataset
from .rec_dataset import RecDataset
from .rec_lmdb_dataset import LMDBDataset
from .predict_dataset import PredictDataset

__all__ = ['build_dataset']

supported_dataset_types = ['BaseDataset', 'DetDataset', 'RecDataset', 'LMDBDataset', 'SynthTextDataset']
supported_dataset_types = ['BaseDataset', 'DetDataset', 'RecDataset', 'LMDBDataset', 'SynthTextDataset', 'PredictDataset']

def build_dataset(
dataset_config: dict,
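
With `PredictDataset` registered here, a `predict` section like the ones in the yaml files above can be passed to `build_dataset`. A hedged sketch; the exact `build_dataset` keyword arguments (notably `is_train`) are an assumption based on the surrounding code, and paths are placeholders:

```python
from mindocr.data import build_dataset  # assumes build_dataset is re-exported by mindocr.data

dataset_config = {
    "type": "PredictDataset",
    "dataset_root": "path/to/dataset_root",   # placeholder, as in the yaml
    "data_dir": "ic15/det/test/ch4_test_images",
    "sample_ratio": 1.0,
    "transform_pipeline": [
        {"DecodeImage": {"img_mode": "RGB", "to_float32": False}},
        {"GridResize": {"factor": 32}},
        {"NormalizeImage": {"bgr_to_rgb": False, "is_hwc": True, "mean": "imagenet", "std": "imagenet"}},
        {"ToCHWImage": None},
    ],
    "output_columns": ["img_path", "image", "raw_img_shape"],
}
loader_config = {"shuffle": False, "batch_size": 1, "drop_remainder": False, "num_workers": 2}

loader = build_dataset(dataset_config, loader_config, is_train=False)
```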
4 changes: 2 additions & 2 deletions mindocr/data/det_dataset.py
@@ -87,8 +87,8 @@ def __init__(self,
if k in _data:
self.output_columns.append(k)
else:
raise ValueError(f'Key {k} does not exist in data (available keys: {_data.keys()}). '
'Please check the name or the completeness transformation pipeline.')
raise ValueError(f"Key '{k}' does not exist in data (available keys: {_data.keys()}). "
"Please check the name or the completeness transformation pipeline.")

def __getitem__(self, index):
data = self.data_list[index].copy() # WARNING: shallow copy. Do deep copy if necessary.
80 changes: 80 additions & 0 deletions mindocr/data/predict_dataset.py
@@ -0,0 +1,80 @@
'''
Inference dataset class
'''
import os
import random
from typing import Union, List

from .base_dataset import BaseDataset
from .transforms.transforms_factory import create_transforms, run_transforms

__all__ = ['PredictDataset']


class PredictDataset(BaseDataset):
> **@hadipash (Collaborator), May 2, 2023:** I feel like `PredictDataset` should be part of `DetDataset` rather than a separate class. We could set a flag such as `prediction` to load images only. Perhaps we can do the same for `RecDataset`.

"""
Notes:
1. The data file structure should be like
├── img_dir
│ ├── 000001.jpg
│ ├── 000002.jpg
│ ├── {image_file_name}
"""
def __init__(self,
# is_train: bool = False,
dataset_root: str = '',
data_dir: str = '',
sample_ratio: Union[List, float] = 1.0,
shuffle: bool = False,
transform_pipeline: List[dict] = None,
output_columns: List[str] = None,
**kwargs):
img_dir = os.path.join(dataset_root, data_dir)
super().__init__(data_dir=img_dir, label_file=None, output_columns=output_columns)
self.data_list = self.load_data_list(img_dir, sample_ratio, shuffle)

# create transform
if transform_pipeline is not None:
self.transforms = create_transforms(transform_pipeline) # , global_config=global_config)
else:
raise ValueError('No transform pipeline is specified!')

# prefetch the data keys, to fit GeneratorDataset
_data = self.data_list[0]
_data = run_transforms(_data, transforms=self.transforms)
_available_keys = list(_data.keys())
if output_columns is None:
self.output_columns = _available_keys
else:
self.output_columns = []
for k in output_columns:
if k in _data:
self.output_columns.append(k)
else:
raise ValueError(f"Key '{k}' does not exist in data (available keys: {_data.keys()}). "
"Please check the name or the completeness transformation pipeline.")

def __getitem__(self, index):
data = self.data_list[index]

# perform transformation on data
data = run_transforms(data, transforms=self.transforms)
output_tuple = tuple(data[k] for k in self.output_columns)

return output_tuple

def load_data_list(self,
img_dir: str,
sample_ratio: float,
shuffle: bool = False,
**kwargs) -> List[dict]:
# read image file name
img_filenames = os.listdir(img_dir)
if shuffle:
img_filenames = random.sample(img_filenames, round(len(img_filenames) * sample_ratio))
else:
img_filenames = img_filenames[:round(len(img_filenames) * sample_ratio)]

img_paths = [{'img_path': os.path.join(img_dir, filename)} for filename in img_filenames]

return img_paths
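
A usage sketch for the class above (paths are placeholders; the pipeline mirrors the `predict.dataset` block in the detection yaml):

```python
from mindocr.data.predict_dataset import PredictDataset

ds = PredictDataset(
    dataset_root="path/to/dataset_root",   # placeholder
    data_dir="ic15/det/test/ch4_test_images",
    sample_ratio=1.0,
    transform_pipeline=[
        {"DecodeImage": {"img_mode": "RGB", "to_float32": False}},
        {"ToCHWImage": None},
    ],
    output_columns=["img_path", "image", "raw_img_shape"],
)

img_path, image, raw_shape = ds[0]   # tuple ordered by output_columns
print(len(ds.data_list), img_path, raw_shape)
```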
1 change: 1 addition & 0 deletions mindocr/data/transforms/general_transforms.py
@@ -41,6 +41,7 @@ def __call__(self, data):
img = img.astype('float32')
data['image'] = img
# data['ori_image'] = img.copy()
data['raw_img_shape'] = img.shape[:2]
return data
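
A hedged illustration of why `raw_img_shape` is recorded here: boxes predicted on the resized image must be mapped back to original-image coordinates before evaluation. The helper below is illustrative, not mindocr's implementation:

```python
import numpy as np

def rescale_boxes(boxes: np.ndarray, raw_img_shape, resized_shape) -> np.ndarray:
    """boxes: (N, K, 2) polygons in (x, y); both shapes are (h, w)."""
    raw_h, raw_w = raw_img_shape
    res_h, res_w = resized_shape
    scale = np.array([raw_w / res_w, raw_h / res_h], dtype=np.float32)
    return boxes * scale  # broadcasts over the last (x, y) axis

boxes = np.array([[[100, 40], [220, 40], [220, 80], [100, 80]]], dtype=np.float32)
print(rescale_boxes(boxes, raw_img_shape=(720, 1280), resized_shape=(736, 1280)))
```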

