
Commit 319dbda

fix conflict

2 parents: 68a33ef + b8c6aaa

File tree: 4 files changed, +183 −14 lines

configs/_base_/models/segnext.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='MSCAN',
        embed_dims=[32, 64, 160, 256],
        mlp_ratios=[8, 8, 4, 4],
        drop_rate=0.0,
        drop_path_rate=0.1,
        depths=[3, 3, 5, 2],
        attention_kernel_sizes=[[5], [1, 7], [1, 11], [1, 21]],
        attention_kernel_paddings=[2, (0, 3), (0, 5), (0, 10)],
        norm_cfg=dict(type='BN', requires_grad=True)),
    decode_head=dict(
        type='LightHamHead',
        in_channels=[64, 160, 256],
        in_index=[1, 2, 3],
        channels=256,
        ham_channels=256,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=ham_norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        ham_kwargs=dict(
            spatial=True,
            MD_S=1,
            MD_D=512,
            MD_R=64,
            train_steps=6,
            eval_steps=7,
            inv_t=100,
            eta=0.9,
            rand_init=True)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
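The ham_kwargs block above maps one-to-one onto the keyword arguments that _MatrixDecomposition2DBase now takes (see the ham_head.py diff below). A minimal sketch of how the dict reaches the decomposition module, assuming LightHamHead forwards ham_kwargs unchanged to NMF2D (the wiring itself is not part of this diff):

# Hypothetical sketch, not part of the commit: assumes LightHamHead hands
# ham_kwargs straight to NMF2D.
from mmseg.models.decode_heads.ham_head import NMF2D

ham_kwargs = dict(spatial=True, MD_S=1, MD_D=512, MD_R=64,
                  train_steps=6, eval_steps=7, inv_t=100, eta=0.9,
                  rand_init=True)
ham = NMF2D(ham_kwargs)  # __init__ unpacks the dict via super().__init__(**args)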
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
_base_ = [
    '../_base_/models/segnext.py',
    '../_base_/default_runtime.py',
]
find_unused_parameters = True
# model settings
norm_cfg = dict(type='BN', requires_grad=True)
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint='/notebooks/mscan_t.pth')),
    decode_head=dict(
        type='LightHamHead',
        in_channels=[64, 160, 256],
        in_index=[1, 2, 3],
        channels=256,
        ham_channels=256,
        ham_kwargs=dict(MD_R=16),
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=ham_norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

evaluation = dict(interval=8000, metric='mIoU')
checkpoint_config = dict(by_epoch=False, interval=8000)
# optimizer
# 0.00006 is the lr for bs 16; should use 0.00006 / 8 as lr here (needs testing)
optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

dataset_type = 'ADE20KDataset'
data_root = '/notebooks/ADEChallengeData2016'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=50,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))

optimizer_config = dict()
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
# these later definitions override the interval=8000 settings earlier in this file
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
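Since this file inherits from ../_base_/models/segnext.py through _base_, the model dict above only overrides what changes: the pretrained backbone checkpoint and the ADE20K head (150 classes, MD_R=16). A quick way to inspect the merged result, assuming mmcv is installed; the config's filename is not visible in this diff, so the path below is hypothetical:

from mmcv import Config

# Hypothetical path: the actual filename of this config is not shown here.
cfg = Config.fromfile('configs/segnext/segnext_tiny_512x512_ade_160k.py')
print(cfg.model.backbone.type)            # 'MSCAN', inherited from the base file
print(cfg.model.decode_head.num_classes)  # 150, overriding the base's 19
print(cfg.model.decode_head.ham_kwargs)   # MD_R=16 merged over the base's kwargs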

mmseg/models/backbones/mscan.py

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ def forward(self, x):
                 - H (int): Height of x.
                 - W (int): Width of x.
         """
+
         x = self.proj(x)
         _, _, H, W = x.shape
         x = self.norm(x)

mmseg/models/decode_heads/ham_head.py

Lines changed: 24 additions & 14 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# Originally from https://github.com/visual-attention-network/segnext
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -11,22 +12,31 @@

 class _MatrixDecomposition2DBase(nn.Module):

-    def __init__(self, args=dict()):
+    def __init__(self,
+                 spatial=True,
+                 MD_S=1,
+                 MD_D=512,
+                 MD_R=64,
+                 train_steps=6,
+                 eval_steps=7,
+                 inv_t=100,
+                 eta=0.9,
+                 rand_init=True):
         super().__init__()

-        self.spatial = args.setdefault('SPATIAL', True)
+        self.spatial = spatial

-        self.S = args.setdefault('MD_S', 1)
-        self.D = args.setdefault('MD_D', 512)
-        self.R = args.setdefault('MD_R', 64)
+        self.S = MD_S
+        self.D = MD_D
+        self.R = MD_R

-        self.train_steps = args.setdefault('TRAIN_STEPS', 6)
-        self.eval_steps = args.setdefault('EVAL_STEPS', 7)
+        self.train_steps = train_steps
+        self.eval_steps = eval_steps

-        self.inv_t = args.setdefault('INV_T', 100)
-        self.eta = args.setdefault('ETA', 0.9)
+        self.inv_t = inv_t
+        self.eta = eta

-        self.rand_init = args.setdefault('RAND_INIT', True)
+        self.rand_init = rand_init

         print('spatial', self.spatial)
         print('S', self.S)
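The refactor above moves the defaults out of args.setdefault(...) calls and into the signature, so overrides become plain keyword arguments. An illustrative comparison (assuming the class is importable as in the diff):

# old API: _MatrixDecomposition2DBase(dict(MD_R=16)) with uppercase-keyed defaults
# new API: defaults live in the signature, overrides are ordinary kwargs
base = _MatrixDecomposition2DBase(MD_R=16)
assert base.R == 16 and base.D == 512  # untouched fields keep their defaults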
@@ -71,14 +81,14 @@ def forward(self, x, return_bases=False):
         D = H * W
         N = C // self.S
         x = x.view(B * self.S, N, D).transpose(1, 2)
-
+        cuda = x.is_cuda
         if not self.rand_init and not hasattr(self, 'bases'):
-            bases = self._build_bases(1, self.S, D, self.R, cuda=True)
+            bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
             self.register_buffer('bases', bases)

         # (S, D, R) -> (B * S, D, R)
         if self.rand_init:
-            bases = self._build_bases(B, self.S, D, self.R, cuda=True)
+            bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
         else:
             bases = self.bases.repeat(B, 1, 1)

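Deriving the cuda flag from the input tensor, instead of the previous hard-coded cuda=True, keeps the NMF bases on the same device as x and unbreaks CPU-only runs. For reference, a hedged sketch of what _build_bases is assumed to do with that flag; its body is not part of this diff:

# Assumed behaviour of _build_bases, reconstructed from its call sites above.
import torch
import torch.nn.functional as F

def _build_bases_sketch(B, S, D, R, cuda=False):
    device = 'cuda' if cuda else 'cpu'
    bases = torch.rand((B * S, D, R), device=device)  # nonnegative random init for NMF
    return F.normalize(bases, dim=1)                  # normalize each basis vector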
@@ -105,7 +115,7 @@ def forward(self, x, return_bases=False):
 class NMF2D(_MatrixDecomposition2DBase):

     def __init__(self, args=dict()):
-        super().__init__(args)
+        super().__init__(**args)

         self.inv_t = 1
