Skip to content

Commit 8323699

Browse files
authored
Merge 12def7c into 52b4fa5
2 parents 52b4fa5 + 12def7c commit 8323699

File tree

11 files changed

+470
-1
lines changed

11 files changed

+470
-1
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# dataset settings
2+
dataset_type = 'CityscapesDataset'
3+
data_root = 'data/cityscapes/'
4+
img_norm_cfg = dict(
5+
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
6+
crop_size = (512, 1024)
7+
train_pipeline = [
8+
dict(type='LoadImageFromFile'),
9+
dict(type='LoadAnnotations'),
10+
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
11+
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12+
dict(type='RandomFlip', prob=0.5),
13+
dict(type='PhotoMetricDistortion'),
14+
dict(type='Normalize', **img_norm_cfg),
15+
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16+
dict(type='DefaultFormatBundle'),
17+
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18+
]
19+
test_pipeline = [
20+
dict(type='LoadImageFromFile'),
21+
dict(
22+
type='MultiScaleFlipAug',
23+
img_scale=(2048, 1024),
24+
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25+
flip=False,
26+
transforms=[
27+
dict(type='Resize', keep_ratio=True),
28+
dict(type='RandomFlip'),
29+
dict(type='Normalize', **img_norm_cfg),
30+
dict(type='ImageToTensor', keys=['img']),
31+
dict(type='Collect', keys=['img']),
32+
])
33+
]
34+
data = dict(
35+
samples_per_gpu=2,
36+
workers_per_gpu=2,
37+
train=dict(
38+
type=dataset_type,
39+
data_root=data_root,
40+
img_dir='leftImg8bit/train',
41+
ann_dir='gtFine/train',
42+
pipeline=train_pipeline),
43+
val=dict(
44+
type=dataset_type,
45+
data_root=data_root,
46+
img_dir='leftImg8bit/val',
47+
ann_dir='gtFine/val',
48+
pipeline=test_pipeline),
49+
test=dict(
50+
type=dataset_type,
51+
data_root=data_root,
52+
img_dir='leftImg8bit/val',
53+
ann_dir='gtFine/val',
54+
pipeline=test_pipeline))

configs/_base_/models/sfnet_r50-d8.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained='open-mmlab://resnet50_v1c',
6+
backbone=dict(
7+
type='ResNetV1c',
8+
depth=50,
9+
num_stages=4,
10+
out_indices=(0, 1, 2, 3),
11+
dilations=(1, 1, 2, 4),
12+
strides=(1, 2, 2, 2),
13+
norm_cfg=norm_cfg,
14+
norm_eval=False,
15+
style='pytorch',
16+
contract_dilation=False),
17+
decode_head=dict(
18+
type='SFNetHead',
19+
in_channels=2048,
20+
in_index=3,
21+
channels=256,
22+
pool_scales=(1, 2, 3, 6),
23+
fpn_inplanes=[256, 512, 1024, 2048],
24+
fpn_dim=256,
25+
dropout_ratio=0,
26+
num_classes=19,
27+
norm_cfg=norm_cfg,
28+
align_corners=False,
29+
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
30+
loss_decode=dict(
31+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
32+
33+
# model training and testing settings
34+
train_cfg=dict(),
35+
test_cfg=dict(mode='whole'))
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained=None,
6+
backbone=dict(
7+
type='ResNetV1d',
8+
depth=50,
9+
num_stages=4,
10+
out_indices=(0, 1, 2, 3),
11+
dilations=(1, 1, 2, 4),
12+
strides=(1, 2, 1, 1),
13+
norm_cfg=norm_cfg,
14+
norm_eval=False,
15+
style='pytorch',
16+
contract_dilation=False),
17+
decode_head=dict(
18+
type='SFNetHead',
19+
in_channels=2048,
20+
in_index=3,
21+
channels=256,
22+
pool_scales=(1, 2, 3, 6),
23+
dropout_ratio=0,
24+
num_classes=19,
25+
norm_cfg=norm_cfg,
26+
align_corners=False,
27+
loss_decode=dict(
28+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29+
30+
# model training and testing settings
31+
train_cfg=dict(),
32+
test_cfg=dict(mode='whole'))
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_base_ = './sfnet_r50-d32_512x1024_80k_cityscapes.py'
2+
model = dict(
3+
pretrained='open-mmlab://resnet18_v1c',
4+
backbone=dict(depth=18),
5+
decode_head=dict(
6+
in_channels=512,
7+
channels=128,
8+
fpn_inplanes=[64, 128, 256, 512],
9+
fpn_dim=128,
10+
),
11+
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'../_base_/models/sfnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
4+
]

configs/sfnet/sfnet_temp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = [
2+
'../_base_/models/sfnet_r50-d8_pd.py',
3+
'../_base_/datasets/cityscapes_pd.py', '../_base_/default_runtime.py',
4+
'../_base_/schedules/schedule_80k.py'
5+
]

configs/sfnet/sfnet_temp18.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_base_ = './sfnet_temp.py'
2+
model = dict(
3+
pretrained=None,
4+
backbone=dict(type='ResNetV1c', depth=18, strides=(1, 2, 2, 2)),
5+
decode_head=dict(
6+
in_channels=512,
7+
channels=128,
8+
fpn_inplanes=[64, 128, 256, 512],
9+
fpn_dim=128,
10+
),
11+
)

mmseg/models/decode_heads/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from .sep_fcn_head import DepthwiseSeparableFCNHead
2121
from .setr_mla_head import SETRMLAHead
2222
from .setr_up_head import SETRUPHead
23+
from .sfnet_head import SFNetHead
2324
from .uper_head import UPerHead
2425

2526
__all__ = [
2627
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
2728
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
2829
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
29-
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
30+
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
31+
'SETRMLAHead', 'SFNetHead'
3032
]
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import torch
2+
import torch.nn as nn
3+
from mmcv.cnn import ConvModule
4+
from mmcv.runner.base_module import BaseModule
5+
6+
from mmseg.ops import resize
7+
from ..builder import HEADS
8+
from .decode_head import BaseDecodeHead
9+
from .psp_head import PPM
10+
11+
12+
@HEADS.register_module()
13+
class SFNetHead(BaseDecodeHead):
14+
"""Semantic Flow for Fast and Accurate SceneParsing.
15+
16+
This head is the implementation of
17+
`SFSegNet <https://arxiv.org/pdf/2002.10120>`_.
18+
19+
Args:
20+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
21+
Module. Default: (1, 2, 3, 6).
22+
fpn_inplanes (list):
23+
The list of feature channels number from backbone.
24+
fpn_dim (int, optional):
25+
The input channels of FAM module.
26+
Default: 256 for ResNet50, 128 for ResNet18.
27+
"""
28+
29+
def __init__(self,
30+
pool_scales=(1, 2, 3, 6),
31+
fpn_inplanes=[256, 512, 1024, 2048],
32+
fpn_dim=256,
33+
**kwargs):
34+
super(SFNetHead, self).__init__(**kwargs)
35+
assert isinstance(pool_scales, (list, tuple))
36+
self.pool_scales = pool_scales
37+
self.fpn_inplanes = fpn_inplanes
38+
self.fpn_dim = fpn_dim
39+
self.psp_modules = PPM(
40+
self.pool_scales,
41+
self.in_channels,
42+
self.in_channels // 4,
43+
bias=True,
44+
conv_cfg=self.conv_cfg,
45+
norm_cfg=self.norm_cfg,
46+
act_cfg=self.act_cfg,
47+
align_corners=True)
48+
self.bottleneck = ConvModule(
49+
self.in_channels * 2,
50+
self.channels,
51+
3,
52+
padding=1,
53+
bias=True,
54+
conv_cfg=self.conv_cfg,
55+
norm_cfg=self.norm_cfg,
56+
act_cfg=self.act_cfg)
57+
58+
self.fpn_in = []
59+
for fpn_inplane in self.fpn_inplanes[:-1]:
60+
self.fpn_in.append(
61+
ConvModule(
62+
fpn_inplane,
63+
self.fpn_dim,
64+
kernel_size=1,
65+
bias=True,
66+
conv_cfg=self.conv_cfg,
67+
norm_cfg=self.norm_cfg,
68+
act_cfg=self.act_cfg,
69+
inplace=False))
70+
self.fpn_in = nn.ModuleList(self.fpn_in)
71+
self.fpn_out = []
72+
self.fpn_out_align = []
73+
self.dsn = []
74+
for i in range(len(self.fpn_inplanes) - 1):
75+
self.fpn_out.append(
76+
ConvModule(
77+
self.fpn_dim,
78+
self.fpn_dim,
79+
kernel_size=3,
80+
stride=1,
81+
padding=1,
82+
bias=False,
83+
conv_cfg=self.conv_cfg,
84+
norm_cfg=self.norm_cfg,
85+
act_cfg=self.act_cfg,
86+
inplace=True))
87+
self.fpn_out_align.append(
88+
AlignedModule(
89+
inplane=self.fpn_dim, outplane=self.fpn_dim // 2))
90+
91+
self.fpn_out = nn.ModuleList(self.fpn_out)
92+
self.fpn_out_align = nn.ModuleList(self.fpn_out_align)
93+
self.conv_last = ConvModule(
94+
len(self.fpn_inplanes) * self.fpn_dim,
95+
self.fpn_dim,
96+
kernel_size=3,
97+
stride=1,
98+
padding=1,
99+
bias=False,
100+
conv_cfg=self.conv_cfg,
101+
norm_cfg=self.norm_cfg,
102+
act_cfg=self.act_cfg,
103+
inplace=True)
104+
105+
def forward(self, inputs):
106+
x = self._transform_inputs(inputs)
107+
psp_outs = [x]
108+
psp_outs.extend(self.psp_modules(x)[::-1])
109+
psp_outs = torch.cat(psp_outs, dim=1)
110+
psp_out = self.bottleneck(psp_outs)
111+
112+
f = psp_out
113+
fpn_feature_list = [psp_out]
114+
115+
for i in reversed(range(len(inputs) - 1)):
116+
conv_x = inputs[i]
117+
conv_x = self.fpn_in[i](conv_x)
118+
f = self.fpn_out_align[i]([conv_x, f])
119+
f = conv_x + f
120+
fpn_feature_list.append(self.fpn_out[i](f))
121+
122+
fpn_feature_list.reverse() # [P2 - P5]
123+
output_size = fpn_feature_list[0].size()[2:]
124+
fusion_list = [fpn_feature_list[0]]
125+
126+
for i in range(1, len(fpn_feature_list)):
127+
fusion_list.append(
128+
nn.functional.interpolate(
129+
fpn_feature_list[i],
130+
output_size,
131+
mode='bilinear',
132+
align_corners=True))
133+
134+
fusion_out = torch.cat(fusion_list, 1)
135+
x = self.conv_last(fusion_out)
136+
output = self.cls_seg(x)
137+
138+
return output
139+
140+
141+
class AlignedModule(BaseModule):
142+
"""The implementation of Flow Alignment Module (FAM).
143+
144+
Args:
145+
inplane (int): The number of FAM input channles.
146+
outplane (int): The number of FAM output channles.
147+
"""
148+
149+
def __init__(self, inplane, outplane, kernel_size=3):
150+
super(AlignedModule, self).__init__()
151+
self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)
152+
self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)
153+
self.flow_make = nn.Conv2d(
154+
outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False)
155+
156+
def forward(self, x):
157+
low_feature, h_feature = x
158+
h_feature_orign = h_feature
159+
h, w = low_feature.size()[2:]
160+
size = (h, w)
161+
low_feature = self.down_l(low_feature)
162+
h_feature = self.down_h(h_feature)
163+
h_feature = resize(
164+
h_feature, size=size, mode='bilinear', align_corners=True)
165+
flow = self.flow_make(torch.cat([h_feature, low_feature], 1))
166+
h_feature = self.flow_warp(h_feature_orign, flow, size=size)
167+
168+
return h_feature
169+
170+
def flow_warp(self, input, flow, size):
171+
"""Implementation of Warp Procedure in Fig 3(b) of original paper,
172+
which is between Flow Field and High Resolution Feature Map.
173+
174+
Args:
175+
input (Tensor): High Resolution Feature Map.
176+
flow (Tensor): Semantic Flow Field that will give
177+
dynamic indication about how to align these
178+
two feature maps effectively.
179+
size (Tuple): Shape of height and width of output.
180+
181+
Returns:
182+
output (Tensor): High Resolution Feature Map after
183+
warped offset and bilinear interpolation.
184+
185+
For example, in cityscapes 1024x2048 dataset with ResNet18 config,
186+
feature map from backbone is:
187+
[[1, 64, 256, 512],
188+
[1, 128, 128, 256],
189+
[1, 256, 64, 128],
190+
[1, 512, 32, 64]]
191+
192+
Thus, its inverse shape of [input, flow, size] is:
193+
[[1, 128, 32, 64], [1, 2, 64, 128], (64, 128)],
194+
[[1, 128, 64, 128], [1, 2, 128, 256], (128, 256)], and
195+
[[1, 128, 128, 256], [1, 2, 256, 512], (256, 512)], respectively.
196+
197+
The final output is:
198+
[[1, 128, 64, 128],
199+
[1, 128, 128, 256],
200+
[1, 128, 256, 512]], respectively.
201+
"""
202+
203+
out_h, out_w = size
204+
n, c, h, w = input.size()
205+
206+
# Warped offset in grid, from -1 to 1.
207+
norm = torch.tensor([[[[out_w,
208+
out_h]]]]).type_as(input).to(input.device)
209+
h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
210+
w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
211+
grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2)
212+
grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
213+
214+
# Warped grid which is corrected the flow offset.
215+
grid = grid + flow.permute(0, 2, 3, 1) / norm
216+
217+
# Sampling mechanism interpolates the values of the 4-neighbors
218+
# (top-left, top-right, bottom-left, and bottom-right) of input.
219+
output = nn.functional.grid_sample(input, grid, align_corners=True)
220+
return output

pth2kvlist.py

Whitespace-only changes.

0 commit comments

Comments
 (0)