Merge pull request #1 from ZwwWayne/swin
Support swin transformer
ZwwWayne authored Jun 16, 2021
2 parents e3277a8 + 23e0fa8 commit fa4a200
Showing 21 changed files with 2,315 additions and 0 deletions.
121 changes: 121 additions & 0 deletions .gitignore
@@ -0,0 +1,121 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

data/
data
.vscode
.idea
.DS_Store

# custom
*.pkl
*.pkl.json
*.log.json
work_dirs/

# Pytorch
*.pth
*.py~
*.sh~
1 change: 1 addition & 0 deletions README.md
@@ -5,3 +5,4 @@ English | [简体中文](README_zh-CN.md)
Based on MIM and other OpenMMLab codebases, you can build new projects conveniently by just writing several python files. In this repository we provide some examples:

1. [mmcls_custom_backbone](/mmcls_custom_backbone): Use custom backbone in MMClassification.
2. [Swin Transformer](/swin_transformer): Minimal code implementation of Swin Transformer for object detection and semantic segmentation.
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -5,3 +5,4 @@
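Based on [MIM](https://github.com/open-mmlab/mim) and the other OpenMMLab codebases, you can build new projects easily by writing just a few Python files. Here we provide the following examples:

1. [mmcls_custom_backbone](/mmcls_custom_backbone): Use a custom backbone in MMClassification.
2. [Swin Transformer](/swin_transformer): A minimal implementation of Swin Transformer that can be used directly for object detection and semantic segmentation.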
125 changes: 125 additions & 0 deletions swin_transformer/README.md
@@ -0,0 +1,125 @@
# Swin Transformer for Object Detection and Segmentation

This is an unofficial implementation of Swin Transformer.
It implements Swin Transformer for object detection and segmentation to show how [MIM](https://github.com/open-mmlab/mim) can accelerate research projects.

## Requirements

- MIM>=0.1.1
- MMCV-full v1.3.5
- MMDetection v2.13.0
- MMSegmentation v0.14.0
- timm

After installing MIM, you can install the remaining packages with the following commands:

```bash
pip install "openmim>=0.1.1"  # install mim through pypi (quoted so the shell does not treat >= as a redirect)
pip install timm  # swin transformer relies on timm
mim install mmcv-full==1.3.5  # install mmcv
MKL_THREADING_LAYER=GNU mim install mmdet==2.13.0  # install mmdet to run object detection
MKL_THREADING_LAYER=GNU mim install mmsegmentation==0.14.0  # install mmseg to run semantic segmentation
```

**Note**: `MKL_THREADING_LAYER=GNU` is a workaround according to the [issue](https://github.com/pytorch/pytorch/issues/37377).
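
Since the requirements pin exact versions, it can help to check the environment before training. The snippet below is a minimal sketch using only the standard library; the helper name `find_mismatches` is hypothetical, and the distribution names are assumed to match the PyPI package names used in the commands above.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Requirements list; the distribution names are
# assumptions about how each package is published on PyPI.
PINNED = {
    "mmcv-full": "1.3.5",
    "mmdet": "2.13.0",
    "mmsegmentation": "0.14.0",
}


def find_mismatches(pinned):
    """Return {name: (expected, found)} for packages that are missing or differ."""
    problems = {}
    for name, expected in pinned.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None  # not installed in this environment
        if found != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

An empty result means every pinned package is installed at the listed version.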

## Explanation

Because MMDetection and MMSegmentation inherit the model registry in MMCV since v2.12.0 and v0.13.0, respectively, we only need one implementation of Swin Transformer, which we add to the model registry of MMCV. We can then use it for object detection and segmentation by modifying the configs.


### Step 1: implement Swin Transformer

The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation](https://github.com/microsoft/Swin-Transformer).
The key file structure is as below:

```
swin_transformer
    |---- configs
        |---- swin_mask_rcnn  # config files to run with MMDetection
            |---- mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py
            |---- mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py
        |---- swin_upernet  # config files to run with MMSegmentation
            |---- upernet_swin-t_512x512_160k_8x2_ade20k.py
    |---- swin
        |---- swin_checkpoint.py  # for checkpoint loading
        |---- swin_transformer.py  # implementation of swin transformer
```

### Step 2: register Swin Transformer into model registry

The key step that allows MMDet and MMSeg to share a single implementation of Swin Transformer is to register the backbone into the registry in MMCV.

```python
import torch.nn as nn
from mmcv.cnn import MODELS


@MODELS.register_module()
class SwinTransformer(nn.Module):
    # code implementation
    def __init__(self, *args, **kwargs):
        super().__init__()
```

It essentially builds a mapping as below

```python
'SwinTransformer' -> <class 'SwinTransformer'>
```
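
To make the mapping concrete, here is a minimal sketch of what a registration decorator does. `ToyRegistry` is a hypothetical stand-in written for this illustration; it is not MMCV's `Registry` class and omits most of its behaviour.

```python
class ToyRegistry:
    """Hypothetical stand-in for a model registry; not MMCV's Registry."""

    def __init__(self, name):
        self.name = name
        self.module_dict = {}  # maps class name -> class object

    def register_module(self):
        # Return a decorator that stores the class under its own name
        # and hands the class back unchanged.
        def decorator(cls):
            self.module_dict[cls.__name__] = cls
            return cls

        return decorator


MODELS = ToyRegistry("model")


@MODELS.register_module()
class SwinTransformer:
    pass


print(MODELS.module_dict)
```

The decorator runs when the module is imported, so the class becomes discoverable by name as a side effect of the import.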

Because MMDetection and MMSegmentation inherit the model registry in MMCV since v2.12.0 and v0.13.0, respectively, their `MODELS` registries are children of the `MODELS` registry in MMCV. Therefore, seen from MMDet/MMSeg, the mapping becomes

```python
'mmcv.SwinTransformer' -> <class 'SwinTransformer'>
```

To enable `MODELS.build()` in MMDet/MMSeg to find the implementation of `SwinTransformer`, we need to specify the scope of the module by writing `mmcv.SwinTransformer`, as you will see in the configs.
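
The scope prefix can be illustrated with a toy parent/child registry. This is a simplified sketch under the assumption that a scoped key is resolved by walking up to the ancestor that owns that scope; `ScopedRegistry` is hypothetical and the real lookup in MMCV is more involved.

```python
class ScopedRegistry:
    """Toy registry with a scope name and an optional parent (illustrative only)."""

    def __init__(self, scope, parent=None):
        self.scope = scope
        self.parent = parent
        self.module_dict = {}

    def register(self, cls):
        self.module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        # 'mmcv.SwinTransformer' -> scope 'mmcv', name 'SwinTransformer'
        scope, _, name = key.rpartition(".")
        if not scope or scope == self.scope:
            # Unscoped keys are only looked up locally in this sketch.
            return self.module_dict.get(name)
        # Walk up the ancestors to find the registry that owns the scope.
        registry = self.parent
        while registry is not None:
            if registry.scope == scope:
                return registry.module_dict.get(name)
            registry = registry.parent
        return None


mmcv_models = ScopedRegistry("mmcv")
mmdet_models = ScopedRegistry("mmdet", parent=mmcv_models)


@mmcv_models.register
class SwinTransformer:
    pass


print(mmdet_models.get("SwinTransformer"))       # not found in the child
print(mmdet_models.get("mmcv.SwinTransformer"))  # resolved through the parent
```

In this sketch the unscoped name fails in the child registry, which mirrors why the configs spell out `mmcv.SwinTransformer`.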

### Step 3: use Swin Transformer through config

To use Swin Transformer, we can simply use the config and the build function

```python
from mmdet.models import build_backbone  # mmseg.models provides one as well

module_cfg = dict(type='mmcv.SwinTransformer')
module = build_backbone(module_cfg)
```
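
Conceptually, a build function pops `type` from the config, looks the class up by name, and passes the remaining keys as constructor arguments. The sketch below shows that idea with a plain dict as the registry; `toy_build` and `TinyBackbone` are hypothetical names, and this is not the builder that MMCV ships.

```python
def toy_build(cfg, registry):
    """Toy builder: look up cfg['type'] in a dict registry and instantiate it."""
    args = dict(cfg)  # copy so the caller's config is left untouched
    obj_type = args.pop("type")
    cls = registry[obj_type]  # raises KeyError for unregistered names
    return cls(**args)


class TinyBackbone:
    # Hypothetical backbone used only for this illustration.
    def __init__(self, embed_dim=96):
        self.embed_dim = embed_dim


REGISTRY = {"TinyBackbone": TinyBackbone}
cfg = dict(type="TinyBackbone", embed_dim=128)
backbone = toy_build(cfg, REGISTRY)
print(type(backbone).__name__, backbone.embed_dim)  # TinyBackbone 128
```

Keeping construction behind a name lookup is what lets a config file swap backbones without touching the training code.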

To run it with MMDetection or MMSegmentation, we need to define the model backbone as below

```python
model = dict(
    type='MaskRCNN',
    pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth',
    backbone=dict(type='mmcv.SwinTransformer'))

custom_imports = dict(
    imports=['swin.swin_transformer'], allow_failed_imports=False)
```
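
The `custom_imports` entry matters because the registration decorator only runs when `swin/swin_transformer.py` is actually imported. A rough sketch of what such a setting amounts to, written with the standard library only (the helper name `import_custom_modules` is an assumption, not the function the framework uses):

```python
import importlib


def import_custom_modules(imports, allow_failed_imports=False):
    """Import each dotted module path and return the modules that loaded."""
    loaded = []
    for name in imports:
        try:
            loaded.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise  # strict mode: a missing module is an error
            # Lenient mode: skip it; a real tool would probably warn here.
    return loaded


# 'json' stands in for 'swin.swin_transformer' so the sketch runs anywhere.
modules = import_custom_modules(["json", "no_such_module_xyz"],
                                allow_failed_imports=True)
print([m.__name__ for m in modules])  # ['json']
```

With `allow_failed_imports=False`, as in the config above, a wrong module path or a missing `PYTHONPATH` entry fails at startup rather than later as an unknown model type.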

## Usage

Assuming you are in the `swin_transformer` directory, you can run it with MMDetection on Slurm using the command below:

```bash
PYTHONPATH='.':$PYTHONPATH mim train mmdet \
    configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py \
    --work-dir ../work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py \
    --launcher slurm --partition $PARTITION \
    --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS
```

To run it with mmseg, we can use the command as below

```bash
PYTHONPATH='.':$PYTHONPATH mim train mmseg \
    configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py \
    --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py \
    --launcher slurm --partition $PARTITION \
    --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS
```
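
If you launch several configs, the two commands above can be assembled programmatically. This is only a sketch: `mim_train_cmd` is a hypothetical helper, it uses just the flags that appear in the commands above, and it builds the argument list without running anything.

```python
import shlex


def mim_train_cmd(codebase, config, work_dir, partition,
                  gpus=8, gpus_per_node=8, srun_args=None):
    """Assemble a `mim train` argument list for a Slurm launch (not executed)."""
    cmd = [
        "mim", "train", codebase, config,
        "--work-dir", work_dir,
        "--launcher", "slurm",
        "--partition", partition,
        "--gpus", str(gpus),
        "--gpus-per-node", str(gpus_per_node),
    ]
    if srun_args:
        cmd += ["--srun-args", srun_args]
    return cmd


cmd = mim_train_cmd(
    "mmseg",
    "configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py",
    "../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k",
    partition="my_partition",  # placeholder partition name
)
print(shlex.join(cmd))
```

The list form can be passed to `subprocess.run` with `env` extended so that `PYTHONPATH` includes the current directory, as in the shell commands.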

## Results

### ADE20K

| Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937-4f09fb29.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937.log.json) |

### COCO

| Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948-bf3d7aa4.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948.log.json) |
| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948-6434d76f.pth) &#124; [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948.log.json) |
54 changes: 54 additions & 0 deletions swin_transformer/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
48 changes: 48 additions & 0 deletions swin_transformer/configs/_base_/datasets/coco_instance.py
@@ -0,0 +1,48 @@
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
16 changes: 16 additions & 0 deletions swin_transformer/configs/_base_/default_runtime_det.py
@@ -0,0 +1,16 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
14 changes: 14 additions & 0 deletions swin_transformer/configs/_base_/default_runtime_seg.py
@@ -0,0 +1,14 @@
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True