Merge pull request #1 from ZwwWayne/swin
Support swin transformer
Showing 21 changed files with 2,315 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
data/ | ||
data | ||
.vscode | ||
.idea | ||
.DS_Store | ||
|
||
# custom | ||
*.pkl | ||
*.pkl.json | ||
*.log.json | ||
work_dirs/ | ||
|
||
# Pytorch | ||
*.pth | ||
*.py~ | ||
*.sh~ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Swin Transformer for Object Detection and Segmentation | ||
|
||
This is an unofficial implementation of Swin Transformer. | ||
It implements Swin Transformer for object detection and segmentation tasks to show how [MIM](https://github.com/open-mmlab/mim) can be used to accelerate research projects. | ||
|
||
## Requirements | ||
|
||
- MIM>=0.1.1 | ||
- MMCV-full v1.3.5 | ||
- MMDetection v2.13.0 | ||
- MMSegmentation v0.14.0 | ||
- timm | ||
|
||
After installing MIM, you can install the remaining dependencies with the following commands: | ||
|
||
```bash | ||
pip install "openmim>=0.1.1" # install mim through pypi (quoted so the shell does not treat >= as a redirect) | ||
pip install timm # swin transformer relies on timm | ||
mim install mmcv-full==1.3.5 # install mmcv | ||
MKL_THREADING_LAYER=GNU mim install mmdet==2.13.0 # install mmdet to run object detection | ||
MKL_THREADING_LAYER=GNU mim install mmsegmentation==0.14.0 # install mmseg to run semantic segmentation | ||
``` | ||
|
||
**Note**: `MKL_THREADING_LAYER=GNU` is a workaround according to the [issue](https://github.com/pytorch/pytorch/issues/37377). | ||
|
||
## Explanation | ||
|
||
Because MMDetection and MMSegmentation have inherited the model registry in MMCV since v2.12.0 and v0.13.0, respectively, we only need one implementation of Swin Transformer, which we add to the model registry of MMCV. We can then use it for object detection and segmentation by modifying configs. | ||
|
||
|
||
### Step 1: implement Swin Transformer | ||
|
||
The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation](https://github.com/microsoft/Swin-Transformer). | ||
The key file structure is as below: | ||
|
||
``` | ||
swin_transformer | ||
|---- configs | ||
|---- swin_mask_rcnn # config files to run with MMDetection | ||
|---- mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py | ||
|---- mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py | ||
|---- swin_upernet # config files to run with MMSegmentation | ||
|---- upernet_swin-t_512x512_160k_8x2_ade20k.py | ||
|---- swin | ||
|---- swin_checkpoint.py # for checkpoint loading | ||
|---- swin_transformer.py # implementation of swin transformer | ||
``` | ||
|
||
### Step 2: register Swin Transformer into model registry | ||
|
||
The key step that allows MMDet and MMSeg to share a single implementation of Swin Transformer is to register the backbone into the registry in MMCV. | ||
|
||
```python | ||
import torch.nn as nn | ||
from mmcv.cnn import MODELS | ||
|
||
|
||
@MODELS.register_module() | ||
class SwinTransformer(nn.Module): | ||
    # code implementation | ||
    def __init__(self, *args, **kwargs): | ||
        super().__init__() | ||
``` | ||
|
||
It essentially builds a mapping as below | ||
|
||
```python | ||
'SwinTransformer' -> <class 'SwinTransformer'> | ||
``` | ||
|
||
Because MMDetection and MMSegmentation have inherited the model registry in MMCV since v2.12.0 and v0.13.0, respectively, their `MODELS` registries are descendants of the `MODELS` registry in MMCV. Therefore, in MMDet/MMSeg the mapping becomes | ||
|
||
```python | ||
'mmcv.SwinTransformer' -> <class 'SwinTransformer'> | ||
``` | ||
|
||
To enable `MODELS.build()` in MMDet/MMSeg to correctly find the implementation of `SwinTransformer`, we need to specify the scope of the module with `mmcv.SwinTransformer`, as you will see in the configs. | ||
|
||
### Step 3: use Swin Transformer through config | ||
|
||
To use Swin Transformer, we can simply use the config and the build function | ||
|
||
```python | ||
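from mmdet.models import build_backbone  # mmseg.models offers the same function | ||
# swin.swin_transformer must be imported beforehand so that the class is registered | ||
module_cfg = dict(type='mmcv.SwinTransformer') | ||
module = build_backbone(module_cfg) | ||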
``` | ||
|
||
To run it with MMDetection or MMSegmentation, we need to define the model backbone as below | ||
|
||
```python | ||
model = dict( | ||
type='MaskRCNN', | ||
pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', | ||
backbone=dict(type='mmcv.SwinTransformer')) | ||
|
||
custom_imports = dict( | ||
imports=['swin.swin_transformer'], allow_failed_imports=False) | ||
``` | ||
|
||
## Usage | ||
|
||
Assuming you are inside the `swin_transformer` directory, you can train with mmdet on a Slurm cluster using the command below | ||
|
||
```bash | ||
PYTHONPATH='.':$PYTHONPATH mim train mmdet configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --work-dir ../work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS | ||
``` | ||
|
||
To run it with mmseg, we can use the command as below | ||
|
||
```bash | ||
PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS | ||
``` | ||
|
||
## Results | ||
|
||
### ADE20K | ||
|
||
| Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download | | ||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | | ||
| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937-4f09fb29.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937.log.json) | | ||
|
||
### COCO | ||
|
||
| Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download | | ||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | | ||
| Swin-T | Mask R-CNN | 1x | 42.6 | 39.5 | [config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948-bf3d7aa4.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948.log.json) | | ||
| Swin-T | Mask R-CNN | FP16 1x | 42.5 | 39.3 | [config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948-6434d76f.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948.log.json) | |
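The scoped lookup described in Step 2 of the README can be illustrated with a toy registry. This is a minimal sketch in plain Python, not MMCV's actual `Registry` class: `ToyRegistry`, `MMCV_MODELS`, `MMDET_MODELS`, and the `embed_dim` argument are hypothetical names used only for illustration. It shows why a child registry needs the `mmcv.` prefix to find a class registered in its parent.

```python
class ToyRegistry:
    """Toy stand-in for a scoped registry (hypothetical, not MMCV's class)."""

    def __init__(self, name, scope, parent=None):
        self.name = name
        self.scope = scope
        self.parent = parent
        self._modules = {}

    def register_module(self):
        """Return a decorator that stores the class under its own name."""
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def get(self, key):
        """Resolve 'Name' locally or 'scope.Name' in a matching ancestor."""
        scope, sep, name = key.rpartition('.')
        if not sep or scope == self.scope:
            return self._modules.get(name)
        registry = self.parent
        while registry is not None:
            if registry.scope == scope:
                return registry._modules.get(name)
            registry = registry.parent
        return None

    def build(self, cfg):
        """Instantiate the class named by cfg['type'] with the other keys."""
        args = dict(cfg)
        key = args.pop('type')
        cls = self.get(key)
        if cls is None:
            raise KeyError(f'{key} is not in the {self.scope} registry')
        return cls(**args)


# Parent registry (plays the role of MMCV) and child registry (MMDet).
MMCV_MODELS = ToyRegistry('model', scope='mmcv')
MMDET_MODELS = ToyRegistry('model', scope='mmdet', parent=MMCV_MODELS)


@MMCV_MODELS.register_module()
class SwinTransformer:
    def __init__(self, embed_dim=96):  # embed_dim is an illustrative argument
        self.embed_dim = embed_dim


# The scoped name resolves through the parent; the bare name does not.
backbone = MMDET_MODELS.build(dict(type='mmcv.SwinTransformer'))
unscoped = MMDET_MODELS.get('SwinTransformer')
```

In this sketch an unscoped key only searches the child's own table, which mirrors why the configs in this commit write `type='mmcv.SwinTransformer'`.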
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# dataset settings | ||
dataset_type = 'ADE20KDataset' | ||
data_root = 'data/ade/ADEChallengeData2016' | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
crop_size = (512, 512) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_semantic_seg']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(2048, 512), | ||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
samples_per_gpu=4, | ||
workers_per_gpu=4, | ||
train=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='images/training', | ||
ann_dir='annotations/training', | ||
pipeline=train_pipeline), | ||
val=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='images/validation', | ||
ann_dir='annotations/validation', | ||
pipeline=test_pipeline), | ||
test=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='images/validation', | ||
ann_dir='annotations/validation', | ||
pipeline=test_pipeline)) |
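The `reduce_zero_label=True` option in the training pipeline above shifts the ADE20K labels so that label 0 is ignored and the remaining classes start at 0. The sketch below is a plain-Python illustration of that remapping as we understand it from MMSegmentation's `LoadAnnotations`; the helper name `reduce_zero_label` and the nested-list input are assumptions for illustration, and the real transform operates on NumPy arrays.

```python
IGNORE_INDEX = 255  # also the seg_pad_val used by the Pad step in this config


def reduce_zero_label(seg_map):
    """Map label 0 to the ignore index and shift every other label down by 1."""
    remapped = []
    for row in seg_map:
        new_row = []
        for label in row:
            if label == 0 or label == IGNORE_INDEX:
                # Label 0 becomes ignored; already-ignored pixels stay ignored.
                new_row.append(IGNORE_INDEX)
            else:
                new_row.append(label - 1)
        remapped.append(new_row)
    return remapped


example = reduce_zero_label([[0, 1, 150], [255, 2, 0]])
print(example)  # [[255, 0, 149], [255, 1, 255]]
```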
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
dataset_type = 'CocoDataset' | ||
data_root = 'data/coco/' | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', with_bbox=True, with_mask=True), | ||
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), | ||
dict(type='RandomFlip', flip_ratio=0.5), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(1333, 800), | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
samples_per_gpu=2, | ||
workers_per_gpu=2, | ||
train=dict( | ||
type=dataset_type, | ||
ann_file=data_root + 'annotations/instances_train2017.json', | ||
img_prefix=data_root + 'train2017/', | ||
pipeline=train_pipeline), | ||
val=dict( | ||
type=dataset_type, | ||
ann_file=data_root + 'annotations/instances_val2017.json', | ||
img_prefix=data_root + 'val2017/', | ||
pipeline=test_pipeline), | ||
test=dict( | ||
type=dataset_type, | ||
ann_file=data_root + 'annotations/instances_val2017.json', | ||
img_prefix=data_root + 'val2017/', | ||
pipeline=test_pipeline)) | ||
evaluation = dict(metric=['bbox', 'segm']) |
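The `Resize` step with `img_scale=(1333, 800)` and `keep_ratio=True`, followed by `Pad` with `size_divisor=32`, determines the image size fed to the detector. The following is a minimal sketch of that arithmetic, assuming the rescaling rule used by MMCV's image rescaling (fit the long edge within 1333 and the short edge within 800, whichever is more restrictive). `keep_ratio_size` and `pad_to_divisor` are standalone helpers written for illustration, not imports from MMCV or MMDetection.

```python
import math


def keep_ratio_size(width, height, img_scale=(1333, 800)):
    """Return (new_width, new_height) after an aspect-preserving resize."""
    long_edge, short_edge = max(img_scale), min(img_scale)
    factor = min(long_edge / max(width, height),
                 short_edge / min(width, height))
    return int(width * factor + 0.5), int(height * factor + 0.5)


def pad_to_divisor(length, divisor=32):
    """Round a side length up to the next multiple of `divisor`."""
    return int(math.ceil(length / divisor)) * divisor


new_w, new_h = keep_ratio_size(640, 480)
padded = (pad_to_divisor(new_w), pad_to_divisor(new_h))
print((new_w, new_h), padded)  # (1067, 800) (1088, 800)
```

For a 640x480 input the short edge is the limiting side, so the image is scaled to 1067x800 and then padded on the width to 1088.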
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
checkpoint_config = dict(interval=1) | ||
# yapf:disable | ||
log_config = dict( | ||
interval=50, | ||
hooks=[ | ||
dict(type='TextLoggerHook'), | ||
# dict(type='TensorboardLoggerHook') | ||
]) | ||
# yapf:enable | ||
# custom_hooks = [dict(type='NumClassCheckHook')] | ||
|
||
dist_params = dict(backend='nccl') | ||
log_level = 'INFO' | ||
load_from = None | ||
resume_from = None | ||
workflow = [('train', 1)] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# yapf:disable | ||
log_config = dict( | ||
interval=50, | ||
hooks=[ | ||
dict(type='TextLoggerHook', by_epoch=False), | ||
# dict(type='TensorboardLoggerHook') | ||
]) | ||
# yapf:enable | ||
dist_params = dict(backend='nccl') | ||
log_level = 'INFO' | ||
load_from = None | ||
resume_from = None | ||
workflow = [('train', 1)] | ||
cudnn_benchmark = True |