From f8079b7ec4053f8d39aeae2f62f981dd2185cd6d Mon Sep 17 00:00:00 2001 From: wwzhang Date: Mon, 24 May 2021 22:57:12 +0800 Subject: [PATCH 01/10] support swin transformer --- swin_transformer/README.md | 24 + .../configs/_base_/datasets/ade20k.py | 54 ++ .../configs/_base_/datasets/coco_instance.py | 48 ++ .../configs/_base_/default_runtime_det.py | 16 + .../configs/_base_/default_runtime_seg.py | 14 + .../models/cascade_mask_rcnn_r50_fpn.py | 196 +++++ .../_base_/models/cascade_rcnn_r50_fpn.py | 179 +++++ .../_base_/models/faster_rcnn_r50_fpn.py | 107 +++ .../_base_/models/mask_rcnn_r50_fpn.py | 120 +++ .../configs/_base_/models/upernet_r50.py | 44 ++ .../configs/_base_/schedules/schedule_160k.py | 9 + .../configs/_base_/schedules/schedule_1x.py | 11 + .../configs/_base_/schedules/schedule_80k.py | 9 + .../mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py | 45 ++ ...mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py | 47 ++ ...n_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py | 94 +++ ...k_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py | 92 +++ .../upernet_swin-t_512x512_160k_8x2_ade20k.py | 53 ++ swin_transformer/swin/swin_checkpoint.py | 491 ++++++++++++ swin_transformer/swin/swin_transformer.py | 710 ++++++++++++++++++ 20 files changed, 2363 insertions(+) create mode 100644 swin_transformer/README.md create mode 100644 swin_transformer/configs/_base_/datasets/ade20k.py create mode 100644 swin_transformer/configs/_base_/datasets/coco_instance.py create mode 100644 swin_transformer/configs/_base_/default_runtime_det.py create mode 100644 swin_transformer/configs/_base_/default_runtime_seg.py create mode 100644 swin_transformer/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py create mode 100644 swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py create mode 100644 swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py create mode 100644 swin_transformer/configs/_base_/models/mask_rcnn_r50_fpn.py create mode 100644 swin_transformer/configs/_base_/models/upernet_r50.py create mode 100644 swin_transformer/configs/_base_/schedules/schedule_160k.py create mode 100644 swin_transformer/configs/_base_/schedules/schedule_1x.py create mode 100644 swin_transformer/configs/_base_/schedules/schedule_80k.py create mode 100644 swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py create mode 100644 swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py create mode 100644 swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py create mode 100644 swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py create mode 100644 swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py create mode 100644 swin_transformer/swin/swin_checkpoint.py create mode 100644 swin_transformer/swin/swin_transformer.py diff --git a/swin_transformer/README.md b/swin_transformer/README.md new file mode 100644 index 0000000..cfbcb14 --- /dev/null +++ b/swin_transformer/README.md @@ -0,0 +1,24 @@ +# Swin Transformer for Object Detection and Segmentation + +This is an unofficial implementation of Swin Transformer. +It implements Swin Transformer for object detection and segmentation tasks to show how we can use [MIM](https://github.com/open-mmlab/mim) to accelerate the research projects. + +## Requirements + +- MMCV-full v1.3.4 +- MMDetection v2.12.0 +- MMSegmentation v0.13.0 + +You can install them after installing mim through the following commands + +```bash +mim install mmcv==1.3.4 +mim install mmdet==2.12.0 +mim install mmsegmentation=0.13.0 +``` + +## Explaination + +Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, we only need the implementation of swin transformer and add it into the model registry of MMCV. Then we can use it for object detection and segmentation by modifying configs. + +The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation]() diff --git a/swin_transformer/configs/_base_/datasets/ade20k.py b/swin_transformer/configs/_base_/datasets/ade20k.py new file mode 100644 index 0000000..efc8b4b --- /dev/null +++ b/swin_transformer/configs/_base_/datasets/ade20k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/swin_transformer/configs/_base_/datasets/coco_instance.py b/swin_transformer/configs/_base_/datasets/coco_instance.py new file mode 100644 index 0000000..f6ea4f4 --- /dev/null +++ b/swin_transformer/configs/_base_/datasets/coco_instance.py @@ -0,0 +1,48 @@ +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(metric=['bbox', 'segm']) diff --git a/swin_transformer/configs/_base_/default_runtime_det.py b/swin_transformer/configs/_base_/default_runtime_det.py new file mode 100644 index 0000000..a7dc78e --- /dev/null +++ b/swin_transformer/configs/_base_/default_runtime_det.py @@ -0,0 +1,16 @@ +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# custom_hooks = [dict(type='NumClassCheckHook')] + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/swin_transformer/configs/_base_/default_runtime_seg.py b/swin_transformer/configs/_base_/default_runtime_seg.py new file mode 100644 index 0000000..b564cc4 --- /dev/null +++ b/swin_transformer/configs/_base_/default_runtime_seg.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True diff --git a/swin_transformer/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py new file mode 100644 index 0000000..9ef6673 --- /dev/null +++ b/swin_transformer/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py @@ -0,0 +1,196 @@ +# model settings +model = dict( + type='CascadeRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py new file mode 100644 index 0000000..cde2a96 --- /dev/null +++ b/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py @@ -0,0 +1,179 @@ +# model settings +model = dict( + type='CascadeRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ]), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100))) diff --git a/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py new file mode 100644 index 0000000..0f038d1 --- /dev/null +++ b/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py @@ -0,0 +1,107 @@ +model = dict( + type='FasterRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/swin_transformer/configs/_base_/models/mask_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/mask_rcnn_r50_fpn.py new file mode 100644 index 0000000..6fc7908 --- /dev/null +++ b/swin_transformer/configs/_base_/models/mask_rcnn_r50_fpn.py @@ -0,0 +1,120 @@ +# model settings +model = dict( + type='MaskRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/swin_transformer/configs/_base_/models/upernet_r50.py b/swin_transformer/configs/_base_/models/upernet_r50.py new file mode 100644 index 0000000..1097496 --- /dev/null +++ b/swin_transformer/configs/_base_/models/upernet_r50.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='UPerHead', + in_channels=[256, 512, 1024, 2048], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/swin_transformer/configs/_base_/schedules/schedule_160k.py b/swin_transformer/configs/_base_/schedules/schedule_160k.py new file mode 100644 index 0000000..5260389 --- /dev/null +++ b/swin_transformer/configs/_base_/schedules/schedule_160k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=16000, metric='mIoU') diff --git a/swin_transformer/configs/_base_/schedules/schedule_1x.py b/swin_transformer/configs/_base_/schedules/schedule_1x.py new file mode 100644 index 0000000..13b3783 --- /dev/null +++ b/swin_transformer/configs/_base_/schedules/schedule_1x.py @@ -0,0 +1,11 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[8, 11]) +runner = dict(type='EpochBasedRunner', max_epochs=12) diff --git a/swin_transformer/configs/_base_/schedules/schedule_80k.py b/swin_transformer/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000..c190cee --- /dev/null +++ b/swin_transformer/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=8000) +evaluation = dict(interval=8000, metric='mIoU') diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py new file mode 100644 index 0000000..e193390 --- /dev/null +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py @@ -0,0 +1,45 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +model = dict( + type='MaskRCNN', + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False), + neck=dict(in_channels=[96, 192, 384, 768])) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.0001, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) +lr_config = dict(warmup_iters=1000, step=[8, 11]) +runner = dict(max_epochs=12) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py new file mode 100644 index 0000000..2d8bdb8 --- /dev/null +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py @@ -0,0 +1,47 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime_det.py' +] + +model = dict( + type='MaskRCNN', + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False), + neck=dict(in_channels=[96, 192, 384, 768])) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.0001, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) +lr_config = dict(warmup_iters=1000, step=[8, 11]) +runner = dict(max_epochs=12) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) + +fp16 = dict(loss_scale=dict(init_scale=512, mode='dynamic')) diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py new file mode 100644 index 0000000..96d5f2d --- /dev/null +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py @@ -0,0 +1,94 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime_det.py' +] + +model = dict( + type='MaskRCNN', + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False), + neck=dict(in_channels=[96, 192, 384, 768])) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# augmentation strategy originates from DETR / Sparse RCNN +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict( + type='AutoAugment', + policies=[[ + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict( + type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ]]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.0001, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) +lr_config = dict(warmup_iters=1000, step=[27, 33]) +runner = dict(max_epochs=36) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) + +fp16 = dict(loss_scale=dict(init_scale=512, mode='dynamic')) diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py new file mode 100644 index 0000000..209a23c --- /dev/null +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py @@ -0,0 +1,92 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime_det.py' +] + +model = dict( + type='MaskRCNN', + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False), + neck=dict(in_channels=[96, 192, 384, 768])) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# augmentation strategy originates from DETR / Sparse RCNN +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict( + type='AutoAugment', + policies=[[ + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict( + type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ]]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.0001, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) +lr_config = dict(warmup_iters=1000, step=[27, 33]) +runner = dict(max_epochs=36) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) diff --git a/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py b/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py new file mode 100644 index 0000000..2098116 --- /dev/null +++ b/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py @@ -0,0 +1,53 @@ +_base_ = [ + '../_base_/models/upernet_r50.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime_seg.py', '../_base_/schedules/schedule_160k.py' +] +model = dict( + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False), + decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), + auxiliary_head=dict(in_channels=384, num_classes=150)) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +lr_config = dict( + _delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, + min_lr=0.0, + by_epoch=False) + +data = dict(samples_per_gpu=2, workers_per_gpu=2) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) diff --git a/swin_transformer/swin/swin_checkpoint.py b/swin_transformer/swin/swin_checkpoint.py new file mode 100644 index 0000000..09e587b --- /dev/null +++ b/swin_transformer/swin/swin_checkpoint.py @@ -0,0 +1,491 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import mmcv +import torch +import torchvision +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.runner import get_dist_info +from mmcv.utils import mkdir_or_exist +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.utils import model_zoo + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + else: + if L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() diff --git a/swin_transformer/swin/swin_transformer.py b/swin_transformer/swin/swin_transformer.py new file mode 100644 index 0000000..65f252d --- /dev/null +++ b/swin_transformer/swin/swin_transformer.py @@ -0,0 +1,710 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from mmdet.utils import get_root_logger +from mmcv.cnn import MODELS +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .swin_checkpoint import load_checkpoint + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative + position bias. + + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, + Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + attn_mask = attn_mask.to(dtype=x.dtype) + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +@MODELS.register_module() +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, + 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() From ede2109dc964567dacbcd65e5ce961f34b95718d Mon Sep 17 00:00:00 2001 From: ZwwWayne Date: Wed, 26 May 2021 12:22:40 +0800 Subject: [PATCH 02/10] add usages --- swin_transformer/README.md | 19 +++++++++++++++++-- swin_transformer/slurm_train.sh | 25 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 swin_transformer/slurm_train.sh diff --git a/swin_transformer/README.md b/swin_transformer/README.md index cfbcb14..9a8d37b 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -12,7 +12,8 @@ It implements Swin Transformer for object detection and segmentation tasks to sh You can install them after installing mim through the following commands ```bash -mim install mmcv==1.3.4 +pip install openmim # install mim through pypi +mim install mmcv-full==1.3.4 mim install mmdet==2.12.0 mim install mmsegmentation=0.13.0 ``` @@ -21,4 +22,18 @@ mim install mmsegmentation=0.13.0 Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, we only need the implementation of swin transformer and add it into the model registry of MMCV. Then we can use it for object detection and segmentation by modifying configs. -The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation]() +The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation](https://github.com/microsoft/Swin-Transformer) + +## Usages + +To run it with mmdet, we can use the command as below + +```bash +sh ./slurm_train.sh mmdet $PARTITION $JOB_NAME configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py ./work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py +``` + +To run it with mmseg, we can use the command as below + +```bash +sh ./slurm_train.sh mmseg $PARTITION $JOB_NAME configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py ./work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py +``` diff --git a/swin_transformer/slurm_train.sh b/swin_transformer/slurm_train.sh new file mode 100644 index 0000000..27a7315 --- /dev/null +++ b/swin_transformer/slurm_train.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -x + +REPO=$1 +PARTITION=$2 +JOB_NAME=$3 +CONFIG=$4 +WORK_DIR=$5 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-"-x SH-IDC1-10-198-4-[92,94]"} +PY_ARGS=${@:6} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + mim run ${REPO} train ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} From fb93cabf1bc22334b4d02d7da1af9d6873ffa6a2 Mon Sep 17 00:00:00 2001 From: wwzhang Date: Tue, 8 Jun 2021 23:45:25 +0800 Subject: [PATCH 03/10] support training with mmdet --- .gitignore | 121 ++++++++++++++++++ swin_transformer/README.md | 23 ++-- .../mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py | 5 +- ...mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py | 3 +- ...n_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py | 3 +- ...k_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py | 3 +- .../upernet_swin-t_512x512_160k_8x2_ade20k.py | 3 +- swin_transformer/slurm_train.sh | 25 ---- swin_transformer/swin/swin_transformer.py | 13 +- 9 files changed, 157 insertions(+), 42 deletions(-) create mode 100644 .gitignore delete mode 100644 swin_transformer/slurm_train.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..77ca0d7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,121 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +data/ +data +.vscode +.idea +.DS_Store + +# custom +*.pkl +*.pkl.json +*.log.json +work_dirs/ + +# Pytorch +*.pth +*.py~ +*.sh~ diff --git a/swin_transformer/README.md b/swin_transformer/README.md index 9a8d37b..ba32a80 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -5,19 +5,24 @@ It implements Swin Transformer for object detection and segmentation tasks to sh ## Requirements -- MMCV-full v1.3.4 -- MMDetection v2.12.0 -- MMSegmentation v0.13.0 +- MIM 0.1.0 +- MMCV-full v1.3.5 +- MMDetection v2.13.0 +- MMSegmentation v0.14.0 +- timm You can install them after installing mim through the following commands ```bash pip install openmim # install mim through pypi -mim install mmcv-full==1.3.4 -mim install mmdet==2.12.0 -mim install mmsegmentation=0.13.0 +pip install timm # swin transformer relies timm +mim install mmcv-full==1.3.5 # install mmcv +MKL_THREADING_LAYER=GNU mim install mmdet==2.13.0 # install mmdet to run object detection +MKL_THREADING_LAYER=GNU mim install mmsegmentation=0.14.0 # install mmseg to run semantic segmentation ``` +**Note**: `MKL_THREADING_LAYER=GNU` is workaround according to the [issue](https://github.com/pytorch/pytorch/issues/37377). + ## Explaination Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, we only need the implementation of swin transformer and add it into the model registry of MMCV. Then we can use it for object detection and segmentation by modifying configs. @@ -26,14 +31,14 @@ The implementation of Swin Transformer and its pre-trained models are taken from ## Usages -To run it with mmdet, we can use the command as below +Assume now you are in the directory under `swin_transformer`, to run it with mmdet and slurm, we can use the command as below ```bash -sh ./slurm_train.sh mmdet $PARTITION $JOB_NAME configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py ./work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py +PYTHONPATH='.':$PYTHONPATH mim train mmdet configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --work-dir ../work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args ${SRUN_ARGS} ``` To run it with mmseg, we can use the command as below ```bash -sh ./slurm_train.sh mmseg $PARTITION $JOB_NAME configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py ./work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py +PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args ${SRUN_ARGS} ``` diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py index e193390..61d3f35 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py @@ -1,7 +1,7 @@ _base_ = [ '../_base_/models/mask_rcnn_r50_fpn.py', '../_base_/datasets/coco_instance.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime_det.py' ] model = dict( @@ -9,7 +9,8 @@ pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', backbone=dict( _delete_=True, - type='SwinTransformer', + # SwinTransformer is registered in the MMCV MODELS registry + type='mmcv.SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py index 2d8bdb8..c738341 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py @@ -9,7 +9,8 @@ pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', backbone=dict( _delete_=True, - type='SwinTransformer', + # SwinTransformer is registered in the MMCV MODELS registry + type='mmcv.SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py index 96d5f2d..7e40bb1 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py @@ -9,7 +9,8 @@ pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', backbone=dict( _delete_=True, - type='SwinTransformer', + # SwinTransformer is registered in the MMCV MODELS registry + type='mmcv.SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py index 209a23c..dd4c220 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_ms-crop-3x_coco.py @@ -9,7 +9,8 @@ pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', backbone=dict( _delete_=True, - type='SwinTransformer', + # SwinTransformer is registered in the MMCV MODELS registry + type='mmcv.SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], diff --git a/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py b/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py index 2098116..1ee2018 100644 --- a/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py +++ b/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py @@ -6,7 +6,8 @@ pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', backbone=dict( _delete_=True, - type='SwinTransformer', + # SwinTransformer is registered in the MMCV MODELS registry + type='mmcv.SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], diff --git a/swin_transformer/slurm_train.sh b/swin_transformer/slurm_train.sh deleted file mode 100644 index 27a7315..0000000 --- a/swin_transformer/slurm_train.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash - -set -x - -REPO=$1 -PARTITION=$2 -JOB_NAME=$3 -CONFIG=$4 -WORK_DIR=$5 -GPUS=${GPUS:-8} -GPUS_PER_NODE=${GPUS_PER_NODE:-8} -CPUS_PER_TASK=${CPUS_PER_TASK:-5} -SRUN_ARGS=${SRUN_ARGS:-"-x SH-IDC1-10-198-4-[92,94]"} -PY_ARGS=${@:6} - -PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ -srun -p ${PARTITION} \ - --job-name=${JOB_NAME} \ - --gres=gpu:${GPUS_PER_NODE} \ - --ntasks=${GPUS} \ - --ntasks-per-node=${GPUS_PER_NODE} \ - --cpus-per-task=${CPUS_PER_TASK} \ - --kill-on-bad-exit=1 \ - ${SRUN_ARGS} \ - mim run ${REPO} train ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} diff --git a/swin_transformer/swin/swin_transformer.py b/swin_transformer/swin/swin_transformer.py index 65f252d..69e3da2 100644 --- a/swin_transformer/swin/swin_transformer.py +++ b/swin_transformer/swin/swin_transformer.py @@ -539,6 +539,7 @@ class SwinTransformer(nn.Module): frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained (str, optional): The path of the pretrained models to initilize the module. """ def __init__(self, @@ -560,7 +561,8 @@ def __init__(self, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, - use_checkpoint=False): + use_checkpoint=False, + pretrained=None): super().__init__() self.pretrain_img_size = pretrain_img_size @@ -570,6 +572,7 @@ def __init__(self, self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages + self.pretrained = pretrained # split image into non-overlapping patches self.patch_embed = PatchEmbed( @@ -654,7 +657,13 @@ def init_weights(self, pretrained=None): pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ - + # MMDetection now use init_cfg to initialize modules while + # MMSegmentation will do that in the next release. + # The logic below makes the initialization behavior compatible with + # MMDetection and MMSegmentation. + if pretrained is None and self.pretrained is not None: + pretrained = self.pretrained + def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) From 8d0872e85205d8ce14f7e2d3136bd05b9277caa1 Mon Sep 17 00:00:00 2001 From: wwzhang Date: Fri, 11 Jun 2021 14:24:46 +0800 Subject: [PATCH 04/10] add results table --- swin_transformer/README.md | 16 ++ .../_base_/models/cascade_rcnn_r50_fpn.py | 179 ------------------ .../_base_/models/faster_rcnn_r50_fpn.py | 107 ----------- .../upernet_swin-t_512x512_160k_8x2_ade20k.py | 2 +- 4 files changed, 17 insertions(+), 287 deletions(-) delete mode 100644 swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py delete mode 100644 swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py rename swin_transformer/configs/{upernet => swin_upernet}/upernet_swin-t_512x512_160k_8x2_ade20k.py (98%) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index ba32a80..8377fc7 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -42,3 +42,19 @@ To run it with mmseg, we can use the command as below ```bash PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args ${SRUN_ARGS} ``` + + +## Results + +### ADE20K + +| Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download | +| :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| Swin-T | UPerNet | 512x512 | 160K | | [config](swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | + +### COCO + +| Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download | +| :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| Swin-T | Mask R-CNN | 1x| | |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model]() | [log]() | +| Swin-T | Mask R-CNN | FP16 1x| | |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model]() | [log]() | diff --git a/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py deleted file mode 100644 index cde2a96..0000000 --- a/swin_transformer/configs/_base_/models/cascade_rcnn_r50_fpn.py +++ /dev/null @@ -1,179 +0,0 @@ -# model settings -model = dict( - type='CascadeRCNN', - pretrained='torchvision://resnet50', - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=True), - norm_eval=True, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - num_outs=5), - rpn_head=dict( - type='RPNHead', - in_channels=256, - feat_channels=256, - anchor_generator=dict( - type='AnchorGenerator', - scales=[8], - ratios=[0.5, 1.0, 2.0], - strides=[4, 8, 16, 32, 64]), - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[.0, .0, .0, .0], - target_stds=[1.0, 1.0, 1.0, 1.0]), - loss_cls=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), - roi_head=dict( - type='CascadeRoIHead', - num_stages=3, - stage_loss_weights=[1, 0.5, 0.25], - bbox_roi_extractor=dict( - type='SingleRoIExtractor', - roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), - out_channels=256, - featmap_strides=[4, 8, 16, 32]), - bbox_head=[ - dict( - type='Shared2FCBBoxHead', - in_channels=256, - fc_out_channels=1024, - roi_feat_size=7, - num_classes=80, - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[0., 0., 0., 0.], - target_stds=[0.1, 0.1, 0.2, 0.2]), - reg_class_agnostic=True, - loss_cls=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0, - loss_weight=1.0)), - dict( - type='Shared2FCBBoxHead', - in_channels=256, - fc_out_channels=1024, - roi_feat_size=7, - num_classes=80, - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[0., 0., 0., 0.], - target_stds=[0.05, 0.05, 0.1, 0.1]), - reg_class_agnostic=True, - loss_cls=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0, - loss_weight=1.0)), - dict( - type='Shared2FCBBoxHead', - in_channels=256, - fc_out_channels=1024, - roi_feat_size=7, - num_classes=80, - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[0., 0., 0., 0.], - target_stds=[0.033, 0.033, 0.067, 0.067]), - reg_class_agnostic=True, - loss_cls=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) - ]), - # model training and testing settings - train_cfg=dict( - rpn=dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.7, - neg_iou_thr=0.3, - min_pos_iou=0.3, - match_low_quality=True, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=256, - pos_fraction=0.5, - neg_pos_ub=-1, - add_gt_as_proposals=False), - allowed_border=0, - pos_weight=-1, - debug=False), - rpn_proposal=dict( - nms_pre=2000, - max_per_img=2000, - nms=dict(type='nms', iou_threshold=0.7), - min_bbox_size=0), - rcnn=[ - dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.5, - neg_iou_thr=0.5, - min_pos_iou=0.5, - match_low_quality=False, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=512, - pos_fraction=0.25, - neg_pos_ub=-1, - add_gt_as_proposals=True), - pos_weight=-1, - debug=False), - dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.6, - neg_iou_thr=0.6, - min_pos_iou=0.6, - match_low_quality=False, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=512, - pos_fraction=0.25, - neg_pos_ub=-1, - add_gt_as_proposals=True), - pos_weight=-1, - debug=False), - dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.7, - neg_iou_thr=0.7, - min_pos_iou=0.7, - match_low_quality=False, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=512, - pos_fraction=0.25, - neg_pos_ub=-1, - add_gt_as_proposals=True), - pos_weight=-1, - debug=False) - ]), - test_cfg=dict( - rpn=dict( - nms_pre=1000, - max_per_img=1000, - nms=dict(type='nms', iou_threshold=0.7), - min_bbox_size=0), - rcnn=dict( - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.5), - max_per_img=100))) diff --git a/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py b/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py deleted file mode 100644 index 0f038d1..0000000 --- a/swin_transformer/configs/_base_/models/faster_rcnn_r50_fpn.py +++ /dev/null @@ -1,107 +0,0 @@ -model = dict( - type='FasterRCNN', - pretrained='torchvision://resnet50', - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=True), - norm_eval=True, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - num_outs=5), - rpn_head=dict( - type='RPNHead', - in_channels=256, - feat_channels=256, - anchor_generator=dict( - type='AnchorGenerator', - scales=[8], - ratios=[0.5, 1.0, 2.0], - strides=[4, 8, 16, 32, 64]), - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[.0, .0, .0, .0], - target_stds=[1.0, 1.0, 1.0, 1.0]), - loss_cls=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), - loss_bbox=dict(type='L1Loss', loss_weight=1.0)), - roi_head=dict( - type='StandardRoIHead', - bbox_roi_extractor=dict( - type='SingleRoIExtractor', - roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), - out_channels=256, - featmap_strides=[4, 8, 16, 32]), - bbox_head=dict( - type='Shared2FCBBoxHead', - in_channels=256, - fc_out_channels=1024, - roi_feat_size=7, - num_classes=80, - bbox_coder=dict( - type='DeltaXYWHBBoxCoder', - target_means=[0., 0., 0., 0.], - target_stds=[0.1, 0.1, 0.2, 0.2]), - reg_class_agnostic=False, - loss_cls=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), - loss_bbox=dict(type='L1Loss', loss_weight=1.0))), - # model training and testing settings - train_cfg=dict( - rpn=dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.7, - neg_iou_thr=0.3, - min_pos_iou=0.3, - match_low_quality=True, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=256, - pos_fraction=0.5, - neg_pos_ub=-1, - add_gt_as_proposals=False), - allowed_border=-1, - pos_weight=-1, - debug=False), - rpn_proposal=dict( - nms_pre=2000, - max_per_img=1000, - nms=dict(type='nms', iou_threshold=0.7), - min_bbox_size=0), - rcnn=dict( - assigner=dict( - type='MaxIoUAssigner', - pos_iou_thr=0.5, - neg_iou_thr=0.5, - min_pos_iou=0.5, - match_low_quality=False, - ignore_iof_thr=-1), - sampler=dict( - type='RandomSampler', - num=512, - pos_fraction=0.25, - neg_pos_ub=-1, - add_gt_as_proposals=True), - pos_weight=-1, - debug=False)), - test_cfg=dict( - rpn=dict( - nms_pre=1000, - max_per_img=1000, - nms=dict(type='nms', iou_threshold=0.7), - min_bbox_size=0), - rcnn=dict( - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.5), - max_per_img=100) - # soft-nms is also supported for rcnn testing - # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) - )) diff --git a/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py b/swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py similarity index 98% rename from swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py rename to swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py index 1ee2018..552022f 100644 --- a/swin_transformer/configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py +++ b/swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py @@ -17,7 +17,7 @@ qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0.2, + drop_path_rate=0.3, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), From b259d11b03a808908c05c350bb1a70d4f5b895d0 Mon Sep 17 00:00:00 2001 From: wwzhang Date: Mon, 14 Jun 2021 10:37:40 +0800 Subject: [PATCH 05/10] update readme --- README.md | 1 + README_zh-CN.md | 1 + swin_transformer/README.md | 6 +++--- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index dd3d4c4..d6c7ea7 100644 --- a/README.md +++ b/README.md @@ -5,3 +5,4 @@ English | [简体中文](README_zh-CN.md) Based on MIM and other OpenMMLab codebases, you can build new projects conveniently by just writing several python files. In this repository we provide some examples: 1. [mmcls_custom_backbone](/mmcls_custom_backbone): Use custom backbone in MMClassification. +2. [Swin Transformer](/swin_transformer): Minimal code implementation of Swin Transformer for object detection and semantic segmentation. diff --git a/README_zh-CN.md b/README_zh-CN.md index 30fe8af..fccf67a 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -5,3 +5,4 @@ 基于 [MIM](https://github.com/open-mmlab/mim) 和 OpenMMLAB 中的代码库,用户仅需编写数个 python 文件就可轻松构建新的项目。在这里我们提供了如下示例 1. [mmcls_custom_backbone](/mmcls_custom_backbone):在 MMClassification 中使用自定义主干网络 +2. [Swin Transformer](/swin_transformer): Swin Transformer 的最简实现,可以直接用于目标检测和语义分割任务 \ No newline at end of file diff --git a/swin_transformer/README.md b/swin_transformer/README.md index 8377fc7..1c63f4f 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -50,11 +50,11 @@ PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x51 | Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | UPerNet | 512x512 | 160K | | [config](swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | +| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | ### COCO | Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | Mask R-CNN | 1x| | |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model]() | [log]() | -| Swin-T | Mask R-CNN | FP16 1x| | |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model]() | [log]() | +| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model]() | [log]() | +| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model]() | [log]() | From 02d84006d4e023bb2a300b88f3f59851259cdf43 Mon Sep 17 00:00:00 2001 From: ZwwWayne Date: Mon, 14 Jun 2021 10:55:02 +0800 Subject: [PATCH 06/10] update model links --- swin_transformer/README.md | 4 ++-- .../swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py | 4 ++-- .../mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index 1c63f4f..a241526 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -56,5 +56,5 @@ PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x51 | Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model]() | [log]() | -| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model]() | [log]() | +| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948-bf3d7aa4.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948.log.json) | +| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948-6434d76f.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948.log.json) | diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py index c738341..df81625 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py @@ -44,5 +44,5 @@ custom_imports = dict( imports=['swin.swin_transformer'], allow_failed_imports=False) - -fp16 = dict(loss_scale=dict(init_scale=512, mode='dynamic')) +# you need to set mode='dynamic' if you are using pytorch<=1.5.0 +fp16 = dict(loss_scale=dict(init_scale=512)) diff --git a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py index 7e40bb1..3929501 100644 --- a/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py +++ b/swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py @@ -91,5 +91,5 @@ custom_imports = dict( imports=['swin.swin_transformer'], allow_failed_imports=False) - -fp16 = dict(loss_scale=dict(init_scale=512, mode='dynamic')) +# you need to set mode='dynamic' if you are using pytorch<=1.5.0 +fp16 = dict(loss_scale=dict(init_scale=512)) From fe39b64f317373fc94f21630dcfbf023fa1e81dc Mon Sep 17 00:00:00 2001 From: ZwwWayne Date: Mon, 14 Jun 2021 10:56:24 +0800 Subject: [PATCH 07/10] update links --- swin_transformer/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index a241526..b802de9 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -50,11 +50,11 @@ PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x51 | Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](swin_transformer/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | +| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | ### COCO | Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948-bf3d7aa4.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948.log.json) | -| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](swin_transformer/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948-6434d76f.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948.log.json) | +| Swin-T | Mask R-CNN | 1x| 42.6| 39.5 |[config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948-bf3d7aa4.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_1x_coco_20210612_135948.log.json) | +| Swin-T | Mask R-CNN | FP16 1x| 42.5|39.3 |[config](/configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948-6434d76f.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco_20210612_135948.log.json) | From 16faed47955c39dc976d0355ded14556a3fb0415 Mon Sep 17 00:00:00 2001 From: ZwwWayne Date: Mon, 14 Jun 2021 10:58:35 +0800 Subject: [PATCH 08/10] update mim version --- swin_transformer/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index b802de9..6b4dbc1 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -5,7 +5,7 @@ It implements Swin Transformer for object detection and segmentation tasks to sh ## Requirements -- MIM 0.1.0 +- MIM>=0.1.1 - MMCV-full v1.3.5 - MMDetection v2.13.0 - MMSegmentation v0.14.0 @@ -14,7 +14,7 @@ It implements Swin Transformer for object detection and segmentation tasks to sh You can install them after installing mim through the following commands ```bash -pip install openmim # install mim through pypi +pip install openmim>=0.1.1 # install mim through pypi pip install timm # swin transformer relies timm mim install mmcv-full==1.3.5 # install mmcv MKL_THREADING_LAYER=GNU mim install mmdet==2.13.0 # install mmdet to run object detection From 977e03c79f3afe357c906968741e8aea5d01e272 Mon Sep 17 00:00:00 2001 From: ZwwWayne Date: Mon, 14 Jun 2021 17:07:09 +0800 Subject: [PATCH 09/10] add more detailed description --- swin_transformer/README.md | 78 +++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 6 deletions(-) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index 6b4dbc1..30adc2e 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -21,29 +21,95 @@ MKL_THREADING_LAYER=GNU mim install mmdet==2.13.0 # install mmdet to run object MKL_THREADING_LAYER=GNU mim install mmsegmentation=0.14.0 # install mmseg to run semantic segmentation ``` -**Note**: `MKL_THREADING_LAYER=GNU` is workaround according to the [issue](https://github.com/pytorch/pytorch/issues/37377). +**Note**: `MKL_THREADING_LAYER=GNU` is a workaround according to the [issue](https://github.com/pytorch/pytorch/issues/37377). ## Explaination -Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, we only need the implementation of swin transformer and add it into the model registry of MMCV. Then we can use it for object detection and segmentation by modifying configs. +Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, respectively, we only need one implementation of swin transformer and add it into the model registry of MMCV. Then we can use it for object detection and segmentation by modifying configs. -The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation](https://github.com/microsoft/Swin-Transformer) + +### Step 1: implement Swin Transformer + +The implementation of Swin Transformer and its pre-trained models are taken from the [official implementation](https://github.com/microsoft/Swin-Transformer). +The key file structure is as below: + +``` +swin_transformer + |---- configs + |---- swin_mask_rcnn # config files to run with MMDetection + |---- mask_rcnn_swim-t-p4-w7_fpn_1x_coco.py + |---- mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py + |---- swin_upernet # config files to run with MMSegmentation + |---- upernet_swin-t_512x512_160k_8x2_ade20k.py + |---- swin + |---- swin_checkpoint.py # for checkout loading + |---- swin_transformer.py # implementation of swin transformer +``` + +### Step 2: register Swin Transformer into model registry + +The key step that allow MMDet and MMSeg to use a unique implementation of Swin Transformer is to register the backbone into the registry in MMCV. + +```python +from mmcv.cnn import MODELS + + +@MODELS.register_module() +class SwinTransformer(nn.Module): + # code implementation + def __init__(self, *args, **kwargs): + super().__init__() +``` + +It essentially builds a mapping as below + +```python +'SwinTransformer' -> +``` + +Because MMDetection and MMSegmentation inherits the model registry in MMCV since v2.12.0 and v0.13.0, their `MODELS` registries are under descendants of the `MODELS` registry in MMCV. Therefore, such a mapping in MMDet/MMSeg becomes + +```python +'mmcv.SwinTransformer' -> +``` + +To enable the `MODEL.build()` in MMDet/MMSeg to correctly find the implementation of `SwinTransformer`, we need to specify the scope of the module by `mmcv.SwinTransformer` as you will see in the configs. + +### Step 3: use Swin Transformer through config + +To use Swin Transformer, we can simply use the config and the build function + +```python +module_cfg = dict(type='mmcv.SwinTransformer') +module = build_backbone(module_cfg) +``` + +To run it with MMDetection or MMSegmentation, we need to define the model backbone as below + +```python +model = dict( + type='MaskRCNN', + pretrained='./pretrain/swin/swin_tiny_patch4_window7_224.pth', + backbone=dict(type='mmcv.SwinTransformer')) + +custom_imports = dict( + imports=['swin.swin_transformer'], allow_failed_imports=False) +``` ## Usages Assume now you are in the directory under `swin_transformer`, to run it with mmdet and slurm, we can use the command as below ```bash -PYTHONPATH='.':$PYTHONPATH mim train mmdet configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --work-dir ../work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args ${SRUN_ARGS} +PYTHONPATH='.':$PYTHONPATH mim train mmdet configs/swin_mask_rcnn/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py \--work-dir ../work_dir/mask_rcnn_swim-t-p4-w7_fpn_fp16_1x_coco.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS ``` To run it with mmseg, we can use the command as below ```bash -PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args ${SRUN_ARGS} +PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py --work-dir ../work_dir/upernet_swin-t_512x512_160k_8x2_ade20k.py --launcher slurm --partition $PARTITION --gpus 8 --gpus-per-node 8 --srun-args $SRUN_ARGS ``` - ## Results ### ADE20K From 23e0fa8191a77c383f688c99827a94cd596d0f19 Mon Sep 17 00:00:00 2001 From: wwzhang Date: Mon, 14 Jun 2021 23:05:24 +0800 Subject: [PATCH 10/10] update upernet log and model link --- swin_transformer/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swin_transformer/README.md b/swin_transformer/README.md index 30adc2e..2fe5927 100644 --- a/swin_transformer/README.md +++ b/swin_transformer/README.md @@ -116,8 +116,7 @@ PYTHONPATH='.':$PYTHONPATH mim train mmseg configs/upernet/upernet_swin-t_512x51 | Backbone | Method | Crop Size | Lr Schd | mIoU | Config | Download | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model]() | [log]() | - +| Swin-T | UPerNet | 512x512 | 160K | 44.3 | [config](/configs/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k.py) | [model](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937-4f09fb29.pth) | [log](https://download.openmmlab.com/mim-example/swin_transformer/swin_upernet/upernet_swin-t_512x512_160k_8x2_ade20k_20210613_201937.log.json) | ### COCO | Backbone | Method | Lr Schd | Bbox mAP | Mask mAP| Config | Download |