forked from open-mmlab/mmdetection
Support new config (open-mmlab#10566)
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Showing 9 changed files with 310 additions and 3 deletions.
@@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import LoadImageFromFile
from mmengine.dataset.sampler import DefaultSampler

from mmdet.datasets import AspectRatioBatchSampler, CocoDataset
from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
                                       RandomFlip, Resize)
from mmdet.evaluation import CocoMetric

# dataset settings
dataset_type = CocoDataset
data_root = 'data/coco/'

# Example of using a different file client
# Method 1: simply set the data root and let the file I/O module
# infer the backend automatically from the prefix (LMDB and Memcached
# are not supported yet)

# data_root = 's3://openmmlab/datasets/detection/coco/'

# Method 2: use `backend_args` (`file_client_args` in versions before 3.0.0rc6)
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
backend_args = None

train_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(type=Resize, scale=(1333, 800), keep_ratio=True),
    dict(type=RandomFlip, prob=0.5),
    dict(type=PackDetInputs)
]
test_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=Resize, scale=(1333, 800), keep_ratio=True),
    # If you don't have ground-truth annotations, delete this step
    dict(type=LoadAnnotations, with_bbox=True),
    dict(
        type=PackDetInputs,
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=True),
    batch_sampler=dict(type=AspectRatioBatchSampler),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline,
        backend_args=backend_args))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
    type=CocoMetric,
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    format_only=False,
    backend_args=backend_args)
test_evaluator = val_evaluator

# Inference on the test dataset and formatting of the output
# results for submission.
# test_dataloader = dict(
#     batch_size=1,
#     num_workers=2,
#     persistent_workers=True,
#     drop_last=False,
#     sampler=dict(type=DefaultSampler, shuffle=False),
#     dataset=dict(
#         type=dataset_type,
#         data_root=data_root,
#         ann_file=data_root + 'annotations/image_info_test-dev2017.json',
#         data_prefix=dict(img='test2017/'),
#         test_mode=True,
#         pipeline=test_pipeline))
# test_evaluator = dict(
#     type=CocoMetric,
#     metric='bbox',
#     format_only=True,
#     ann_file=data_root + 'annotations/image_info_test-dev2017.json',
#     outfile_prefix='./work_dirs/coco_detection/test')
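The `backend_args` defined above is passed straight through to MMEngine's file I/O layer by `LoadImageFromFile`, the dataset, and `CocoMetric`. A minimal sketch of what that means in practice, assuming `mmengine.fileio.get` and an illustrative image path:

from mmengine.fileio import get

# With backend_args=None the local-disk backend is used; with the Petrel
# mapping commented out above, the same relative path would be rewritten
# to the s3:// bucket before being read. The file name is illustrative.
backend_args = None
img_bytes = get('data/coco/train2017/000000000009.jpg',
                backend_args=backend_args)
print(len(img_bytes))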
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.runner import LogProcessor
from mmengine.visualization import LocalVisBackend

from mmdet.engine.hooks import DetVisualizationHook
from mmdet.visualization import DetLocalVisualizer

default_scope = None

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, interval=1),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=DetVisualizationHook))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type=LocalVisBackend)]
visualizer = dict(
    type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer')
log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False
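Because these are plain Python objects, a downstream new-style config can pull this file in with `read_base()` and adjust individual hooks in place. A hedged sketch, assuming the same relative layout as the configs in this commit; `max_keep_ckpts=3` is only an example tweak:

from mmengine.config import read_base
from mmengine.hooks import CheckpointHook

with read_base():
    from .._base_.default_runtime import *

# Keep only the three most recent checkpoints; still save once per epoch.
default_hooks.update(
    dict(checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)))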
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import nms
from torch.nn import BatchNorm2d

from mmdet.models import (FPN, DetDataPreprocessor, FocalLoss, L1Loss, ResNet,
                          RetinaHead, RetinaNet)
from mmdet.models.task_modules import (AnchorGenerator, DeltaXYWHBBoxCoder,
                                       MaxIoUAssigner, PseudoSampler)

# model settings
model = dict(
    type=RetinaNet,
    data_preprocessor=dict(
        type=DetDataPreprocessor,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type=FPN,
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5),
    bbox_head=dict(
        type=RetinaHead,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type=AnchorGenerator,
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type=DeltaXYWHBBoxCoder,
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type=FocalLoss,
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type=L1Loss, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(
            type=MaxIoUAssigner,
            pos_iou_thr=0.5,
            neg_iou_thr=0.4,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        sampler=dict(
            type=PseudoSampler),  # Focal loss should use PseudoSampler
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type=nms, iou_threshold=0.5),
        max_per_img=100))
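Since every `type` field above holds the class itself rather than a registry string, the detector can be built straight from the loaded config. A hedged sketch, assuming the model file sits at the illustrative path and that mmengine's pure-Python configs resolve the lazy imports on access, as described in its config tutorial:

from mmengine.config import Config

from mmdet.registry import MODELS

# Illustrative path to the model base config added in this commit.
cfg = Config.fromfile('mmdet/configs/_base_/models/retinanet_r50_fpn.py')
detector = MODELS.build(cfg.model)
print(type(detector).__name__)  # expected: RetinaNet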
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.optim.sgd import SGD

# training schedule for 1x
train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=12, val_interval=1)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# learning rate
param_scheduler = [
    dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type=MultiStepLR,
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

# optimizer
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))

# Default settings for scaling the learning rate automatically
# - `enable`: whether to scale the learning rate automatically (off by default).
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
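When `enable=True`, MMEngine rescales the configured learning rate by the ratio of the actual total batch size to `base_batch_size` (the linear scaling rule). A small worked example of the arithmetic, with an illustrative 4-GPU setup:

# Linear scaling rule applied by auto_scale_lr when it is enabled.
base_lr = 0.02              # lr configured above for 8 GPUs x 2 imgs/GPU
base_batch_size = 16
actual_batch_size = 4 * 2   # e.g. 4 GPUs x 2 imgs per GPU (illustrative)

scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)            # 0.01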
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

with read_base():
    from .._base_.models.retinanet_r50_fpn import *
    from .._base_.datasets.coco_detection import *
    from .._base_.schedules.schedule_1x import *
    from .._base_.default_runtime import *
    from .retinanet_tta import *

from torch.optim.sgd import SGD

# optimizer
optim_wrapper.update(
    dict(optimizer=dict(type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0001)))
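Everything above is merged at load time, with the `optim_wrapper.update(...)` override applied last. A quick way to confirm what the merged config contains, assuming the illustrative path used earlier:

from mmengine.config import Config

cfg = Config.fromfile('mmdet/configs/retinanet/retinanet_r50_fpn_1x_coco.py')

print(cfg.optim_wrapper.optimizer.lr)   # 0.01, the overridden value
print(cfg.train_cfg.max_epochs)         # 12, inherited from schedule_1x
print(cfg.train_dataloader.batch_size)  # 2, inherited from coco_detection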
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel

tta_model = dict(
    type=DetTTAModel,
    tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))

img_scales = [(1333, 800), (666, 400), (2000, 1200)]
tta_pipeline = [
    dict(type=LoadImageFromFile, backend_args=None),
    dict(
        type=TestTimeAug,
        transforms=[
            [dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales],
            [dict(type=RandomFlip, prob=1.),
             dict(type=RandomFlip, prob=0.)],
            [dict(type=LoadAnnotations, with_bbox=True)],
            [
                dict(
                    type=PackDetInputs,
                    meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                               'scale_factor', 'flip', 'flip_direction'))
            ]
        ])
]
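`TestTimeAug` enumerates the cross-product of the sub-lists above, so each test image is expanded into 3 scales x 2 flip states = 6 views whose predictions `DetTTAModel` merges with NMS. A hedged sketch of what switching TTA on amounts to, roughly mirroring what mmdet 3.x's test entry point does and using the illustrative config path from before:

from mmengine.config import Config

cfg = Config.fromfile('mmdet/configs/retinanet/retinanet_r50_fpn_1x_coco.py')

# Wrap the detector in DetTTAModel and swap the test pipeline for the
# 6-view TTA pipeline defined above.
cfg.tta_model['module'] = cfg.model
cfg.model = cfg.tta_model
cfg.test_dataloader['dataset']['pipeline'] = cfg.tta_pipeline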