
[SPTS] train #1756

Merged — 30 commits, Mar 7, 2023

Commits
2089b02
[Feature] Add RepeatAugSampler
gaotongxiao Jan 13, 2023
88574e0
initial commit
gaotongxiao Jan 13, 2023
e2fd325
spts inference done
gaotongxiao Jan 23, 2023
393506a
merge repeat_aug (bug in multi-node?)
gaotongxiao Jan 23, 2023
3db3e1a
fix inference
gaotongxiao Jan 29, 2023
6e164fc
train done
gaotongxiao Jan 31, 2023
f45ae14
rm readme
gaotongxiao Jan 31, 2023
a8748f2
Revert "merge repeat_aug (bug in multi-node?)"
gaotongxiao Jan 31, 2023
208d785
Revert "[Feature] Add RepeatAugSampler"
gaotongxiao Jan 31, 2023
3d1a93a
remove utils
gaotongxiao Jan 31, 2023
73adab4
readme & conversion script
gaotongxiao Jan 31, 2023
31756c4
Merge branch 'dev-1.x' into spts
gaotongxiao Jan 31, 2023
ac6e3ef
update readme
gaotongxiao Jan 31, 2023
1b0ee1e
fix
gaotongxiao Jan 31, 2023
05c12f9
Merge branch 'batch_aug' into spts
gaotongxiao Jan 31, 2023
55ce451
optimize
gaotongxiao Jan 31, 2023
c7dbd4e
rename cfg & del compose
gaotongxiao Jan 31, 2023
5e85ed8
fix
gaotongxiao Jan 31, 2023
2551f2a
fix
gaotongxiao Feb 1, 2023
d2b6f9b
tmp commit
gaotongxiao Feb 15, 2023
22f118f
update training setting
gaotongxiao Feb 28, 2023
aabf7c8
update cfg
gaotongxiao Mar 2, 2023
a982351
update readme
gaotongxiao Mar 2, 2023
84b0ac5
e2e metric
gaotongxiao Mar 2, 2023
243b258
Merge branch 'dev-1.x' into spts
gaotongxiao Mar 2, 2023
e44b563
update cfg
gaotongxiao Mar 3, 2023
3416ec8
fix
gaotongxiao Mar 3, 2023
e6cafbc
update readme
gaotongxiao Mar 3, 2023
cb1e2c7
fix
gaotongxiao Mar 3, 2023
7ede2fe
update
gaotongxiao Mar 7, 2023
4 changes: 2 additions & 2 deletions mmocr/utils/polygon_utils.py
@@ -152,8 +152,8 @@ def crop_polygon(polygon: ArrayLike,
np.array or None: Cropped polygon. If the polygon is not within the
crop box, return None.
"""
poly = poly2shapely(polygon)
crop_poly = poly2shapely(bbox2poly(crop_box))
poly = poly_make_valid(poly2shapely(polygon))
crop_poly = poly_make_valid(poly2shapely(bbox2poly(crop_box)))
poly_cropped = poly.intersection(crop_poly)
if poly_cropped.area == 0. or not isinstance(
poly_cropped, shapely.geometry.polygon.Polygon):
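
For context, a minimal usage sketch of the patched `crop_polygon` (assuming an MMOCR dev-1.x checkout with this change applied): a self-intersecting "bowtie" polygon is invalid in shapely, and intersecting it with the crop box could previously raise a topology error; with this patch, `poly_make_valid` repairs both operands before the intersection.

```python
# Illustrative only: exercises crop_polygon on an invalid (self-intersecting)
# polygon, which is the case the poly_make_valid() wrapping guards against.
import numpy as np
from mmocr.utils.polygon_utils import crop_polygon

# Flattened (x1, y1, x2, y2, ...) coordinates of a self-intersecting "bowtie".
bowtie = np.array([0., 0., 10., 10., 10., 0., 0., 10.])
crop_box = np.array([0, 0, 8, 8])  # (x1, y1, x2, y2)

cropped = crop_polygon(bowtie, crop_box)
print(cropped)  # np.ndarray of cropped coordinates, or None; no topology error
```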
59 changes: 39 additions & 20 deletions projects/SPTS/README.md
@@ -36,10 +36,24 @@ $env:PYTHONPATH=Get-Location

### Dataset

**As of now, the implementation uses datasets provided by SPTS, but these datasets
will be available in MMOCR's dataset preparer very soon.**
As of now, the implementation uses the datasets provided by SPTS for pre-training, and MMOCR's datasets for fine-tuning and testing. This is because the test split of SPTS's datasets does not contain enough information for end-to-end evaluation, and MMOCR's dataset preparer does not yet support all the datasets used in SPTS. *We are working on this issue, and they will be available in MMOCR's dataset preparer very soon.*

Download and extract all the datasets into `data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).
Please follow these steps to prepare the datasets:

- Download and extract all the SPTS datasets into `spts-data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).

- Use [Dataset Preparer](https://mmocr.readthedocs.io/en/dev-1.x/user_guides/data_prepare/dataset_preparer.html) to prepare `icdar2013`, `icdar2015` and `totaltext` for `textspotting` task.

```shell
# Run in MMOCR's root directory
python tools/dataset_converters/prepare_dataset.py icdar2013 icdar2015 totaltext --task textspotting
```

Then create a soft link to the `data/` directory in the project root directory:

```shell
ln -s ../../data/ .
```

### Training commands

@@ -48,52 +62,57 @@ In the current directory, run the following command to train the model:
#### Pretrain

```bash
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --amp
```

To train on multiple GPUs, e.g. 8 GPUs, run the following command:

```bash
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --amp
```

#### Finetune

Similarly, run the following command to finetune the model on a dataset (e.g. icdar2013):

```bash
mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}"
mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}" --amp
```

To finetune on multiple GPUs, e.g. 8 GPUs, run the following command:

```bash
mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}"
mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}" --amp
```

### Testing commands

In the current directory, run the following command to test the model on a dataset (e.g. icdar2013):

```bash
mim test mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
mim test mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
```

## Results
## Convert Weights from Official Repo

The weights from MMOCR are on the way. Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.
Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.

```bash
python tools/ckpt_adapter.py [SPTS_WEIGHTS_PATH] [MMOCR_WEIGHTS_PATH]
```
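
For reference, here is a hypothetical sketch of what such an adapter typically does: load the official checkpoint and remap parameter names into an MMOCR-style `state_dict`. The actual renaming rules live in `tools/ckpt_adapter.py` and will differ; the `module.` prefix handling below is purely illustrative.

```python
# Hypothetical sketch; the real key mapping is in tools/ckpt_adapter.py.
import sys

import torch


def convert(spts_ckpt_path: str, mmocr_ckpt_path: str) -> None:
    ckpt = torch.load(spts_ckpt_path, map_location='cpu')
    # Official checkpoints often nest the weights under a 'model' key.
    state = ckpt.get('model', ckpt)
    new_state = {}
    for key, value in state.items():
        # Illustrative rule only: strip a DistributedDataParallel 'module.' prefix.
        new_state[key.removeprefix('module.')] = value
    # MMOCR/MMEngine checkpoints store the weights under 'state_dict'.
    torch.save({'state_dict': new_state}, mmocr_ckpt_path)


if __name__ == '__main__':
    convert(sys.argv[1], sys.argv[2])
```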

Here are the results obtained with the converted weights. The results are lower than the original ones due to the difference in the test split of the datasets, which will be addressed in the next update.
## Results

All the models are trained on 8x A100 GPUs with AMP on (`--amp`). The overall batch size is 64.

| Name | Pretrained | Generic | Weak | Strong | Download |
| ---------- | --------------------------------------------------------------------------------------- | ------- | ----- | ------ | ------------------------------------------------------------------------------------- |
| ICDAR 2013 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 87.10 | 91.46 | 93.41 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/spts_resnet50_200e_icdar2013-64cb4d31.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/20230303_140316.log) |
| ICDAR 2015 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 69.09 | 73.45 | 79.19 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/spts_resnet50_200e_icdar2015-d6e8621c.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/20230302_230026.log) |

| Name | Model | E2E-None-Hmean |
| :--------: | :-------------------: | :------------: |
| ICDAR 2013 | ic13.pth (converted) | 0.8573 |
| ctw1500 | ctw1500 (converted) | 0.6304 |
| totaltext | totaltext (converted) | 0.6596 |
| Name | Pretrained | None-Hmean | Full-Hmean | Download |
| :-------: | -------------------------------------------------------------------------------------- | :--------: | :--------: | ------------------------------------------------------------------------------------- |
| Totaltext | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 73.99 | 82.34 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/spts_resnet50_200e_totaltext-e3521af6.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/20230303_103040.log) |

## Citation

@@ -136,15 +155,15 @@ A project does not necessarily have to be finished in a single PR, but it's esse

<!-- As this template does. -->

- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness
- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

- [ ] Milestone 3: Good to be a part of our core package!
- [x] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings
- [x] Type hints and docstrings

<!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96) -->

2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/icdar2013-spts.py
@@ -1,4 +1,4 @@
icdar2013_textspotting_data_root = 'data/icdar2013'
icdar2013_textspotting_data_root = 'spts-data/icdar2013'

icdar2013_textspotting_train = dict(
type='AdelDataset',
2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/icdar2015-spts.py
@@ -1,4 +1,4 @@
icdar2015_textspotting_data_root = 'data/icdar2015'
icdar2015_textspotting_data_root = 'spts-data/icdar2015'

icdar2015_textspotting_train = dict(
type='AdelDataset',
Expand Down
1 change: 0 additions & 1 deletion projects/SPTS/config/_base_/datasets/icdar2015.py
@@ -11,5 +11,4 @@
data_root=icdar2015_textspotting_data_root,
ann_file='textspotting_test.json',
test_mode=True,
# indices=50,
pipeline=None)
2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/mlt-spts.py
@@ -1,4 +1,4 @@
mlt_textspotting_data_root = 'data/mlt2017'
mlt_textspotting_data_root = 'spts-data/mlt2017'

mlt_textspotting_train = dict(
type='AdelDataset',
2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/syntext1-spts.py
@@ -1,4 +1,4 @@
syntext1_textspotting_data_root = 'data/syntext1'
syntext1_textspotting_data_root = 'spts-data/syntext1'

syntext1_textspotting_train = dict(
type='AdelDataset',
2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/syntext2-spts.py
@@ -1,4 +1,4 @@
syntext2_textspotting_data_root = 'data/syntext2'
syntext2_textspotting_data_root = 'spts-data/syntext2'

syntext2_textspotting_train = dict(
type='AdelDataset',
2 changes: 1 addition & 1 deletion projects/SPTS/config/_base_/datasets/totaltext-spts.py
@@ -1,4 +1,4 @@
totaltext_textspotting_data_root = 'data/totaltext'
totaltext_textspotting_data_root = 'spts-data/totaltext'

totaltext_textspotting_train = dict(
type='AdelDataset',
15 changes: 15 additions & 0 deletions projects/SPTS/config/_base_/datasets/totaltext.py
@@ -0,0 +1,15 @@
totaltext_textspotting_data_root = 'data/totaltext'

totaltext_textspotting_train = dict(
type='OCRDataset',
data_root=totaltext_textspotting_data_root,
ann_file='textspotting_train.json',
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=None)

totaltext_textspotting_test = dict(
type='OCRDataset',
data_root=totaltext_textspotting_data_root,
ann_file='textspotting_test.json',
test_mode=True,
pipeline=None)
3 changes: 2 additions & 1 deletion projects/SPTS/config/_base_/default_runtime.py
@@ -4,7 +4,8 @@
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
randomness = dict(seed=None)

randomness = dict(seed=42)

default_hooks = dict(
timer=dict(type='IterTimerHook'),
9 changes: 3 additions & 6 deletions projects/SPTS/config/spts/_base_spts_resnet50.py
@@ -1,4 +1,5 @@
custom_imports = dict(imports=['spts'], allow_failed_imports=False)
custom_imports = dict(
imports=['projects.SPTS.spts'], allow_failed_imports=False)

file_client_args = dict(backend='disk')

@@ -65,10 +66,7 @@
type='LoadOCRAnnotationsWithBezier',
with_bbox=True,
with_label=True,
with_bezier=True,
with_text=True),
dict(type='Bezier2Polygon'),
dict(type='ConvertText', dictionary=dictionary),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
@@ -87,7 +85,7 @@
with_text=True),
dict(type='Bezier2Polygon'),
dict(type='FixInvalidPolygon'),
dict(type='ConvertText', dictionary=dictionary),
dict(type='ConvertText', dictionary=dict(**dictionary, num_bins=0)),
dict(type='RemoveIgnored'),
dict(type='RandomCrop', min_side_ratio=0.5),
dict(
@@ -119,7 +117,6 @@
hue=0.5)
],
prob=0.5),
# dict(type='Polygon2Bezier'),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
63 changes: 63 additions & 0 deletions projects/SPTS/config/spts/_base_spts_resnet50_mmocr.py
@@ -0,0 +1,63 @@
_base_ = '_base_spts_resnet50.py'

test_pipeline = [
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
dict(
type='RescaleToShortSide',
short_side_lens=[1000],
long_side_bound=1824),
dict(
type='LoadOCRAnnotations',
with_bbox=True,
with_label=True,
with_polygon=True,
with_text=True),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]

train_pipeline = [
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_bbox=True,
with_label=True,
with_polygon=True,
with_text=True),
dict(type='FixInvalidPolygon'),
dict(type='RemoveIgnored'),
dict(type='RandomCrop', min_side_ratio=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='RandomRotate',
max_angle=30,
pad_with_fixed_color=True,
use_canvas=True)
],
prob=0.3),
dict(type='FixInvalidPolygon'),
dict(
type='RandomChoiceResize',
scales=[(640, 1600), (672, 1600), (704, 1600), (736, 1600),
(768, 1600), (800, 1600), (832, 1600), (864, 1600),
(896, 1600)],
keep_ratio=True),
dict(
type='RandomApply',
transforms=[
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.5)
],
prob=0.5),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
59 changes: 0 additions & 59 deletions projects/SPTS/config/spts/spts_resnet50_350e_ctw1500-spts.py

This file was deleted.
