-
Notifications
You must be signed in to change notification settings - Fork 2.7k
[Feature] add SFSegNet head #733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MengzhangLI
wants to merge
20
commits into
open-mmlab:master
Choose a base branch
from
MengzhangLI:SFNet
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
e92411e
readme_link_fix
MengzhangLI b825a5a
readme_link_fix
MengzhangLI 3538fb0
readme_link_fix
MengzhangLI 626d3c4
Fix UNet FCN Download link [#415]
MengzhangLI 66dabfe
Fix UNet FCN Download link [#415]
MengzhangLI 28911ab
Fix DMNet Download link [#548]
MengzhangLI 747e342
Fix DMNet Download link [#548]
MengzhangLI 899efe3
update_to_upstream_master_branch
MengzhangLI 75bc450
Merge branch 'open-mmlab:master' into master
MengzhangLI fea70b5
Merge branch 'open-mmlab:master' into master
MengzhangLI d151a59
Merge branch 'open-mmlab:master' into master
MengzhangLI 3f9411d
add SFNet head
MengzhangLI 6f9346f
sfnet
MengzhangLI 3ab94d6
sfnet
MengzhangLI ca56cee
sfnet
MengzhangLI 9a41916
sfnet with docstring
MengzhangLI 390ff09
sfnet with docstring
MengzhangLI a0c666b
add returns in sfnet head
MengzhangLI de1757b
add returns in sfnet head
MengzhangLI 12def7c
re-implement paddleseg
MengzhangLI File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# dataset settings | ||
dataset_type = 'CityscapesDataset' | ||
data_root = 'data/cityscapes/' | ||
img_norm_cfg = dict( | ||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) | ||
crop_size = (512, 1024) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations'), | ||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_semantic_seg']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(2048, 1024), | ||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
samples_per_gpu=2, | ||
workers_per_gpu=2, | ||
train=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='leftImg8bit/train', | ||
ann_dir='gtFine/train', | ||
pipeline=train_pipeline), | ||
val=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='leftImg8bit/val', | ||
ann_dir='gtFine/val', | ||
pipeline=test_pipeline), | ||
test=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
img_dir='leftImg8bit/val', | ||
ann_dir='gtFine/val', | ||
pipeline=test_pipeline)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# model settings | ||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained='open-mmlab://resnet50_v1c', | ||
backbone=dict( | ||
type='ResNetV1c', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
dilations=(1, 1, 2, 4), | ||
strides=(1, 2, 2, 2), | ||
norm_cfg=norm_cfg, | ||
norm_eval=False, | ||
style='pytorch', | ||
contract_dilation=False), | ||
decode_head=dict( | ||
type='SFNetHead', | ||
in_channels=2048, | ||
in_index=3, | ||
channels=256, | ||
pool_scales=(1, 2, 3, 6), | ||
fpn_inplanes=[256, 512, 1024, 2048], | ||
fpn_dim=256, | ||
dropout_ratio=0, | ||
MengzhangLI marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), | ||
MengzhangLI marked this conversation as resolved.
Show resolved
Hide resolved
|
||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
|
||
# model training and testing settings | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='whole')) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# model settings | ||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict( | ||
type='ResNetV1d', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
dilations=(1, 1, 2, 4), | ||
strides=(1, 2, 1, 1), | ||
norm_cfg=norm_cfg, | ||
norm_eval=False, | ||
style='pytorch', | ||
contract_dilation=False), | ||
decode_head=dict( | ||
type='SFNetHead', | ||
in_channels=2048, | ||
in_index=3, | ||
channels=256, | ||
pool_scales=(1, 2, 3, 6), | ||
dropout_ratio=0, | ||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
|
||
# model training and testing settings | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='whole')) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
_base_ = './sfnet_r50-d32_512x1024_80k_cityscapes.py' | ||
model = dict( | ||
pretrained='open-mmlab://resnet18_v1c', | ||
backbone=dict(depth=18), | ||
decode_head=dict( | ||
in_channels=512, | ||
channels=128, | ||
fpn_inplanes=[64, 128, 256, 512], | ||
fpn_dim=128, | ||
), | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_base_ = [ | ||
'../_base_/models/sfnet_r50-d8.py', '../_base_/datasets/cityscapes.py', | ||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
_base_ = [ | ||
'../_base_/models/sfnet_r50-d8_pd.py', | ||
'../_base_/datasets/cityscapes_pd.py', '../_base_/default_runtime.py', | ||
'../_base_/schedules/schedule_80k.py' | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
_base_ = './sfnet_temp.py' | ||
model = dict( | ||
pretrained=None, | ||
backbone=dict(type='ResNetV1c', depth=18, strides=(1, 2, 2, 2)), | ||
decode_head=dict( | ||
in_channels=512, | ||
channels=128, | ||
fpn_inplanes=[64, 128, 256, 512], | ||
fpn_dim=128, | ||
), | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.cnn import ConvModule | ||
from mmcv.runner.base_module import BaseModule | ||
|
||
from mmseg.ops import resize | ||
from ..builder import HEADS | ||
from .decode_head import BaseDecodeHead | ||
from .psp_head import PPM | ||
|
||
|
||
@HEADS.register_module() | ||
class SFNetHead(BaseDecodeHead): | ||
"""Semantic Flow for Fast and Accurate SceneParsing. | ||
|
||
This head is the implementation of | ||
`SFSegNet <https://arxiv.org/pdf/2002.10120>`_. | ||
|
||
Args: | ||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid | ||
Module. Default: (1, 2, 3, 6). | ||
fpn_inplanes (list): | ||
The list of feature channels number from backbone. | ||
fpn_dim (int, optional): | ||
The input channels of FAM module. | ||
Default: 256 for ResNet50, 128 for ResNet18. | ||
""" | ||
|
||
def __init__(self, | ||
pool_scales=(1, 2, 3, 6), | ||
fpn_inplanes=[256, 512, 1024, 2048], | ||
fpn_dim=256, | ||
**kwargs): | ||
super(SFNetHead, self).__init__(**kwargs) | ||
assert isinstance(pool_scales, (list, tuple)) | ||
self.pool_scales = pool_scales | ||
self.fpn_inplanes = fpn_inplanes | ||
self.fpn_dim = fpn_dim | ||
self.psp_modules = PPM( | ||
self.pool_scales, | ||
self.in_channels, | ||
self.in_channels // 4, | ||
bias=True, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg, | ||
align_corners=True) | ||
self.bottleneck = ConvModule( | ||
self.in_channels * 2, | ||
self.channels, | ||
3, | ||
padding=1, | ||
bias=True, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg) | ||
|
||
self.fpn_in = [] | ||
for fpn_inplane in self.fpn_inplanes[:-1]: | ||
self.fpn_in.append( | ||
ConvModule( | ||
fpn_inplane, | ||
self.fpn_dim, | ||
kernel_size=1, | ||
bias=True, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg, | ||
inplace=False)) | ||
self.fpn_in = nn.ModuleList(self.fpn_in) | ||
self.fpn_out = [] | ||
self.fpn_out_align = [] | ||
self.dsn = [] | ||
for i in range(len(self.fpn_inplanes) - 1): | ||
self.fpn_out.append( | ||
ConvModule( | ||
self.fpn_dim, | ||
self.fpn_dim, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
bias=False, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg, | ||
inplace=True)) | ||
self.fpn_out_align.append( | ||
AlignedModule( | ||
inplane=self.fpn_dim, outplane=self.fpn_dim // 2)) | ||
|
||
self.fpn_out = nn.ModuleList(self.fpn_out) | ||
self.fpn_out_align = nn.ModuleList(self.fpn_out_align) | ||
self.conv_last = ConvModule( | ||
len(self.fpn_inplanes) * self.fpn_dim, | ||
self.fpn_dim, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
bias=False, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg, | ||
inplace=True) | ||
|
||
def forward(self, inputs): | ||
x = self._transform_inputs(inputs) | ||
psp_outs = [x] | ||
psp_outs.extend(self.psp_modules(x)[::-1]) | ||
psp_outs = torch.cat(psp_outs, dim=1) | ||
psp_out = self.bottleneck(psp_outs) | ||
|
||
f = psp_out | ||
fpn_feature_list = [psp_out] | ||
|
||
for i in reversed(range(len(inputs) - 1)): | ||
conv_x = inputs[i] | ||
conv_x = self.fpn_in[i](conv_x) | ||
f = self.fpn_out_align[i]([conv_x, f]) | ||
f = conv_x + f | ||
fpn_feature_list.append(self.fpn_out[i](f)) | ||
|
||
fpn_feature_list.reverse() # [P2 - P5] | ||
output_size = fpn_feature_list[0].size()[2:] | ||
fusion_list = [fpn_feature_list[0]] | ||
|
||
for i in range(1, len(fpn_feature_list)): | ||
fusion_list.append( | ||
nn.functional.interpolate( | ||
fpn_feature_list[i], | ||
output_size, | ||
mode='bilinear', | ||
align_corners=True)) | ||
|
||
fusion_out = torch.cat(fusion_list, 1) | ||
x = self.conv_last(fusion_out) | ||
output = self.cls_seg(x) | ||
|
||
return output | ||
|
||
|
||
class AlignedModule(BaseModule): | ||
"""The implementation of Flow Alignment Module (FAM). | ||
|
||
Args: | ||
inplane (int): The number of FAM input channles. | ||
outplane (int): The number of FAM output channles. | ||
""" | ||
|
||
def __init__(self, inplane, outplane, kernel_size=3): | ||
MengzhangLI marked this conversation as resolved.
Show resolved
Hide resolved
|
||
super(AlignedModule, self).__init__() | ||
self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False) | ||
self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) | ||
self.flow_make = nn.Conv2d( | ||
outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) | ||
|
||
def forward(self, x): | ||
low_feature, h_feature = x | ||
h_feature_orign = h_feature | ||
h, w = low_feature.size()[2:] | ||
size = (h, w) | ||
low_feature = self.down_l(low_feature) | ||
h_feature = self.down_h(h_feature) | ||
h_feature = resize( | ||
h_feature, size=size, mode='bilinear', align_corners=True) | ||
flow = self.flow_make(torch.cat([h_feature, low_feature], 1)) | ||
h_feature = self.flow_warp(h_feature_orign, flow, size=size) | ||
|
||
return h_feature | ||
|
||
def flow_warp(self, input, flow, size): | ||
"""Implementation of Warp Procedure in Fig 3(b) of original paper, | ||
which is between Flow Field and High Resolution Feature Map. | ||
|
||
Args: | ||
input (Tensor): High Resolution Feature Map. | ||
flow (Tensor): Semantic Flow Field that will give | ||
dynamic indication about how to align these | ||
two feature maps effectively. | ||
size (Tuple): Shape of height and width of output. | ||
|
||
Returns: | ||
output (Tensor): High Resolution Feature Map after | ||
warped offset and bilinear interpolation. | ||
|
||
For example, in cityscapes 1024x2048 dataset with ResNet18 config, | ||
feature map from backbone is: | ||
[[1, 64, 256, 512], | ||
[1, 128, 128, 256], | ||
[1, 256, 64, 128], | ||
[1, 512, 32, 64]] | ||
|
||
Thus, its inverse shape of [input, flow, size] is: | ||
[[1, 128, 32, 64], [1, 2, 64, 128], (64, 128)], | ||
[[1, 128, 64, 128], [1, 2, 128, 256], (128, 256)], and | ||
[[1, 128, 128, 256], [1, 2, 256, 512], (256, 512)], respectively. | ||
|
||
The final output is: | ||
[[1, 128, 64, 128], | ||
[1, 128, 128, 256], | ||
[1, 128, 256, 512]], respectively. | ||
""" | ||
|
||
out_h, out_w = size | ||
n, c, h, w = input.size() | ||
|
||
# Warped offset in grid, from -1 to 1. | ||
norm = torch.tensor([[[[out_w, | ||
out_h]]]]).type_as(input).to(input.device) | ||
h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) | ||
w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) | ||
MengzhangLI marked this conversation as resolved.
Show resolved
Hide resolved
|
||
grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) | ||
grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) | ||
|
||
# Warped grid which is corrected the flow offset. | ||
grid = grid + flow.permute(0, 2, 3, 1) / norm | ||
|
||
# Sampling mechanism interpolates the values of the 4-neighbors | ||
# (top-left, top-right, bottom-left, and bottom-right) of input. | ||
output = nn.functional.grid_sample(input, grid, align_corners=True) | ||
return output |
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.