[Feature] Support Side Adapter Network #3232

Merged: xiexinch merged 67 commits into open-mmlab:dev-1.x from angiecao:angiecao/add_SAN_infer on Sep 20, 2023. The diff below shows changes from 58 of the 67 commits.

Commits (67)
48344d2  add vit configs (angiecao)
236c1fe  Add AdapterSideNetwork decodehead (angiecao)
41b8db9  Add text_encoder in CLIP (angiecao)
0213604  Add configs for SAN (angiecao)
0ec5238  Add model covert for SAN (angiecao)
5b3b9e6  fix circular import (angiecao)
7de7c0c  fix bg_embed (angiecao)
3fac1ec  add test units for SAN (angiecao)
16662d8  Add regex, ftfy requirements (angiecao)
340688f  change 'multimodel' to 'multimodal' (angiecao)
1fffd59  Replace file_client with fileio (angiecao)
c992216  add vit configs (angiecao)
75811e3  Add AdapterSideNetwork decodehead (angiecao)
91dcfa1  Add text_encoder in CLIP (angiecao)
c7f3167  Add configs for SAN (angiecao)
6c2f01e  Add model covert for SAN (angiecao)
c0cad61  fix circular import (angiecao)
3e7bcca  fix bg_embed (angiecao)
7fb34b6  add test units for SAN (angiecao)
125bf17  Add regex, ftfy requirements (angiecao)
7b82303  change 'multimodel' to 'multimodal' (angiecao)
8f01fb4  freeze parameters (angiecao)
2ed4f6c  weight init (angiecao)
34a230c  add convert CLIP model into mmseg style (angiecao)
7229c00  train_pipeline & optimizer & scheduler (angiecao)
53573a6  add deep supervised (angiecao)
3e3bbea  add multimodal dependencies (angiecao)
6449203  Merge branch 'angiecao/add_SAN_infer' of github.com:angiecao/mmsegmen… (angiecao)
a327d70  fix test_san_head (angiecao)
887ae9f  [CodeCamp2023-154] Add semantic label to the segmentation visualizati… (CastleDream)
6afe346  add train loss for SAN (angiecao)
da69d7e  fix weights_init (angiecao)
790add5  [Fix] Fix module PascalContextDataset (#3235) (angiecao)
8997794  [Fix] Added ignore_index and one hot encoding for dice loss (#3237) (yeedrag)
759bd7e  [CodeCamp2023-367] Add pp_mobileseg model (#3239) (Yang-Changhui)
0c5db50  [Project] Support CAT-Seg from CVPR2023 (#3098) (SheffieldCao)
aabc3ee  [Fix] update pp_mobileseg ckpt links (#3254) (xiexinch)
f8ecb4b  [Doc]fix inference_segmentor to inference_model (#3261) (ooooo-create)
7bec537  fix dataset configs (angiecao)
03071cb  add assert & remove unused code (angiecao)
5b1ae9c  fix conflits with dev-1.x (angiecao)
327e757  fix conflits with dev-1.x (angiecao)
487827a  merge infer to train (angiecao)
48ccb01  fix loss (angiecao)
dd0b717  fix conflicts with add_SAN_train (angiecao)
54f8912  delete unused files (angiecao)
a459ddf  fix no valid category of ground truth (angiecao)
9a9a505  change to naive dice loss & set weight decay to 1e-4 (angiecao)
fdf6617  add amp&clip grad (angiecao)
b0e9d22  fix loss_mask_ce is nan when use amp (angiecao)
89600c9  fix loss_cls_ce (angiecao)
16c65fc  fix error when use ViT/L-14 in clip2mmseg (angiecao)
a09c41a  Revert "fix loss_cls_ce" (angiecao)
4954fc8  fix avg_factor when class_weight is not None (angiecao)
ea2abc9  change num_total_masks to float32 (angiecao)
9362e12  add assertion for InstanceData in match_cost (angiecao)
d69dc51  add test units for train SAN (angiecao)
f3e69a6  change batchsize for train (angiecao)
28b741f  remove unused & add reference link & change LayerNorm to LayerNorm2d (angiecao)
8a18f37  Merge branch 'dev-1.x' into angiecao/add_SAN_infer (angiecao)
e45845d  fix pre-commit error (angiecao)
fc5424e  combine two cross attention methods into one (angiecao)
da670c7  fix docstring (angiecao)
d30d37f  Merge branch 'angiecao/add_SAN_infer' of github.com:angiecao/mmsegmen… (angiecao)
8e50943  modify pretrained_Part load method & replace reduce_mean with all_reduce (angiecao)
17a4279  add readme.md (angiecao)
75a6338  add readme.md (angiecao)
New file (+137 lines): base model settings for SAN
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size_divisor=640,
    test_cfg=dict(size_divisor=32))

num_classes = 171
model = dict(
    type='MultimodalEncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/clip_vit_base_patch16_224.pth',
    asymetric_input=True,
    encoder_resolution=0.5,
    image_encoder=dict(
        type='VisionTransformer',
        img_size=(224, 224),
        patch_size=16,
        patch_pad=0,
        in_channels=3,
        embed_dims=768,
        num_layers=9,
        num_heads=12,
        mlp_ratio=4,
        out_origin=True,
        out_indices=(2, 5, 8),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        output_cls_token=True,
        patch_bias=False,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        act_cfg=dict(type='QuickGELU'),
        norm_eval=False,
        interpolate_mode='bicubic',
        frozen_exclude=['pos_embed']),
    text_encoder=dict(
        type='CLIPTextEncoder',
        dataset_name=None,
        templates='vild',
        embed_dims=512,
        num_layers=12,
        num_heads=8,
        mlp_ratio=4,
        output_dims=512,
        cache_feature=True,
        cat_bg=True,
        norm_cfg=dict(type='LN', eps=1e-5)),
    decode_head=dict(
        type='SideAdapterCLIPHead',
        num_classes=num_classes,
        deep_supervision_idxs=[7],
        san_cfg=dict(
            in_channels=3,
            clip_channels=768,
            embed_dims=240,
            patch_size=16,
            patch_bias=True,
            num_queries=100,
            cfg_encoder=dict(
                num_encode_layer=8,
                num_heads=6,
                mlp_ratio=4),
            fusion_index=[0, 1, 2, 3],
            cfg_decoder=dict(
                num_heads=12,
                num_layers=1,
                embed_channels=256,
                mlp_channels=256,
                num_mlp=3,
                rescale=True),
            norm_cfg=dict(type='LN', eps=1e-6)),
        maskgen_cfg=dict(
            sos_token_format='cls_token',
            sos_token_num=100,
            cross_attn=False,
            num_layers=3,
            embed_dims=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            out_dims=512,
            final_norm=True,
            act_cfg=dict(type='QuickGELU'),
            norm_cfg=dict(type='LN', eps=1e-5),
            frozen_exclude=[]),
        align_corners=False,
        train_cfg=dict(
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(
                        type='DiceCost',
                        weight=5.0,
                        pred_act=True,
                        eps=1.0)
                ])),
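        # The three match costs above mirror the three decode losses below
        # (class CE / mask CE / Dice) with the same 2.0 / 5.0 / 5.0 weights,
        # keeping Hungarian assignment consistent with the supervision.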
        loss_decode=[dict(type='CrossEntropyLoss',
                          loss_name='loss_cls_ce',
                          loss_weight=2.0,
                          class_weight=[1.0] * num_classes + [0.1]),
                     dict(type='CrossEntropyLoss',
                          use_sigmoid=True,
                          loss_name='loss_mask_ce',
                          loss_weight=5.0),
                     dict(type='DiceLoss',
                          ignore_index=None,
                          naive_dice=True,
                          eps=1,
                          loss_name='loss_mask_dice',
                          loss_weight=5.0)
                     ]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
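For context, a minimal sketch of running inference with a SAN config through mmseg 1.x's high-level APIs. The config and checkpoint paths below are hypothetical placeholders, and the checkpoint is assumed to have been converted to mmseg style first (see the clip2mmseg commits above):

# Minimal inference sketch (not part of this diff); paths are hypothetical.
from mmseg.apis import inference_model, init_model

config_file = 'configs/san/san-vit-b16_coco-stuff164k-640x640.py'  # hypothetical
checkpoint_file = 'checkpoints/san_vit-b16_mmseg.pth'  # hypothetical, converted weights

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # any RGB image path
# result is a SegDataSample; pred_sem_seg holds the predicted label map
print(result.pred_sem_seg.data.shape)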
configs/mask2former/mask2former_r50_8xb2-160k_vaihingen-512x512.py (new file, +200 lines):
_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/vaihingen.py']

custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
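# The 'mmdet.'-prefixed components below (MSDeformAttnPixelDecoder,
# HungarianAssigner, losses and match costs) come from mmdet, so mmdet
# must be installed for this config to build.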
crop_size = (512, 512)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=crop_size,
    test_cfg=dict(size_divisor=32))
num_classes = 6
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='ResNet',
        depth=50,
        deep_stem=False,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='SyncBN', requires_grad=False),
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    decode_head=dict(
        type='Mask2FormerHead',
        in_channels=[256, 512, 1024, 2048],
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_classes=num_classes,
        num_queries=100,
        num_transformer_feat_level=3,
        align_corners=False,
        pixel_decoder=dict(
            type='mmdet.MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=True,
                        norm_cfg=None,
                        init_cfg=None),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True))),
                init_cfg=None),
            positional_encoding=dict(  # SinePositionalEncoding
                num_feats=128, normalize=True),
            init_cfg=None),
        enforce_decoder_input_project=False,
        positional_encoding=dict(  # SinePositionalEncoding
            num_feats=128, normalize=True),
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    act_cfg=dict(type='ReLU', inplace=True),
                    ffn_drop=0.0,
                    dropout_layer=None,
                    add_identity=True)),
            init_cfg=None),
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='mmdet.DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        train_cfg=dict(
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='mmdet.HungarianAssigner',
                match_costs=[
                    dict(type='mmdet.ClassificationCost', weight=2.0),
                    dict(
                        type='mmdet.CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(
                        type='mmdet.DiceCost',
                        weight=5.0,
                        pred_act=True,
                        eps=1.0)
                ]),
            sampler=dict(type='mmdet.MaskPseudoSampler'))),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# dataset config
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(
        type='RandomChoiceResize',
        scales=[int(512 * x * 0.1) for x in range(5, 21)],
        resize_type='ResizeShortestEdge',
        max_size=2048),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
train_dataloader = dict(batch_size=2, dataset=dict(pipeline=train_pipeline))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optimizer = dict(
    type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    clip_grad=dict(max_norm=0.01, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        norm_decay_mult=0.0))
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=160000,
        by_epoch=False)
]
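# PolyLR applies the standard polynomial decay schedule, roughly
#   lr(t) = (lr0 - eta_min) * (1 - t / end) ** power + eta_min,
# so the rate falls smoothly from 1e-4 to 0 over 160k iterations
# (by_epoch=False means t counts iterations, not epochs).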
# training schedule for 160k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=160000, val_interval=5000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=5000,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
#   or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
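For context, a minimal sketch of launching training from this config in Python, mirroring what mmseg's tools/train.py does through the mmengine Runner; the work_dir value is an arbitrary choice:

# Minimal training-launch sketch (not part of this diff), assuming
# mmsegmentation and the Vaihingen dataset are set up locally.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/mask2former/mask2former_r50_8xb2-160k_vaihingen-512x512.py')
cfg.work_dir = 'work_dirs/mask2former_vaihingen'  # arbitrary output directory

runner = Runner.from_cfg(cfg)  # builds model, dataloaders and hooks from the config
runner.train()  # runs the 160k-iteration IterBasedTrainLoop defined above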
Review comment: Might remove this file.