Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add SFSegNet head #733

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions configs/_base_/models/sfnet_r50-d8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# model settings
# Normalization config shared by the backbone and the decode head below.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        # Every stage is listed here; SFNetHead consumes all four levels.
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=False,
    ),
    decode_head=dict(
        type='SFNetHead',
        in_channels=2048,
        in_index=3,
        channels=256,
        pool_scales=(1, 2, 3, 6),
        # One entry per backbone output listed in `out_indices`.
        fpn_inplanes=[256, 512, 1024, 2048],
        fpn_dim=256,
        dropout_ratio=0,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
    ),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
11 changes: 11 additions & 0 deletions configs/sfnet/sfnet_r18-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = './sfnet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check whether the r18 is v1c or v1d.

backbone=dict(depth=18),
decode_head=dict(
in_channels=512,
channels=128,
fpn_inplanes=[64, 128, 256, 512],
fpn_dim=128,
),
)
4 changes: 4 additions & 0 deletions configs/sfnet/sfnet_r50-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/sfnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
4 changes: 3 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .sfnet_head import SFNetHead
from .uper_head import UPerHead

# Names of the decode heads exported by this package. Each name appears
# exactly once; adjacent string literals must be comma-separated, otherwise
# Python silently concatenates them into a single bogus entry.
__all__ = [
    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'SFNetHead'
]
206 changes: 206 additions & 0 deletions mmseg/models/decode_heads/sfnet_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .psp_head import PPM


@HEADS.register_module()
class SFNetHead(BaseDecodeHead):
    """Semantic Flow for Fast and Accurate Scene Parsing.

    This head is the implementation of
    `SFSegNet <https://arxiv.org/pdf/2002.10120>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
        fpn_inplanes (list[int] | tuple[int]):
            The list of feature channels number from backbone. All but the
            last entry get a 1x1 lateral conv; the number of entries also
            sets the input width of the final fusion conv.
            Default: (256, 512, 1024, 2048).
        fpn_dim (int, optional):
            The input channels of FAM module.
            Default: 256 (the value used with ResNet50; the ResNet18
            config uses 128).
    """

    def __init__(self,
                 pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048),
                 fpn_dim=256,
                 **kwargs):
        super(SFNetHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        # Store a private copy: the default is immutable and a caller-owned
        # sequence is never aliased by the head.
        self.fpn_inplanes = list(fpn_inplanes)
        self.fpn_dim = fpn_dim

        # Pyramid pooling over the selected input feature, then a 3x3
        # bottleneck that fuses the raw feature with every pooled branch.
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels * 2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels * 2,
            self.channels,
            3,
            padding=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # 1x1 lateral convs for every backbone level except the last one.
        # NOTE(review): these layers (and fpn_out / conv_last below) use
        # plain nn.BatchNorm2d and do not follow ``self.norm_cfg`` --
        # confirm this is intended.
        self.fpn_in = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(fpn_inplane, self.fpn_dim, 1),
                nn.BatchNorm2d(self.fpn_dim), nn.ReLU(inplace=False))
            for fpn_inplane in self.fpn_inplanes[:-1]
        ])

        # Not used by ``forward``; retained so that existing attribute
        # access on the head keeps working.
        self.dsn = []

        # Per-level output conv and flow alignment module. They are created
        # interleaved, level by level, so that parameter initialisation
        # consumes the RNG in a fixed, reproducible order.
        fpn_out = []
        fpn_out_align = []
        for _ in range(len(self.fpn_inplanes) - 1):
            fpn_out.append(
                nn.Sequential(
                    nn.Conv2d(
                        self.fpn_dim,
                        self.fpn_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False),
                    nn.BatchNorm2d(self.fpn_dim),
                    nn.ReLU(inplace=True),
                ))
            fpn_out_align.append(
                AlignedModule(
                    inplane=self.fpn_dim, outplane=self.fpn_dim // 2))
        self.fpn_out = nn.ModuleList(fpn_out)
        self.fpn_out_align = nn.ModuleList(fpn_out_align)

        # Fuses the concatenation of all pyramid levels back to ``fpn_dim``.
        self.conv_last = nn.Sequential(
            nn.Conv2d(
                len(self.fpn_inplanes) * self.fpn_dim,
                self.fpn_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False), nn.BatchNorm2d(self.fpn_dim),
            nn.ReLU(inplace=True))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level backbone features. Level ``i``
                (for all but the last level) is fed to ``fpn_in[i]``, so its
                channel count must equal ``fpn_inplanes[i]``.

        Returns:
            Tensor: The result of ``cls_seg`` applied to the fused pyramid
                feature, at the spatial size of the first pyramid level.
        """
        x = self._transform_inputs(inputs)
        # Raw feature followed by the pooled branches in reversed order.
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x)[::-1])
        psp_out = self.bottleneck(torch.cat(psp_outs, dim=1))

        # Top-down pathway: walk from the second-to-last level to the first,
        # warping the running feature onto each lateral feature before the
        # two are summed.
        f = psp_out
        fpn_feature_list = [psp_out]
        for i in reversed(range(len(inputs) - 1)):
            conv_x = self.fpn_in[i](inputs[i])
            f = conv_x + self.fpn_out_align[i]([conv_x, f])
            fpn_feature_list.append(self.fpn_out[i](f))

        fpn_feature_list.reverse()  # [P2 - P5]
        output_size = fpn_feature_list[0].size()[2:]

        # Resize every remaining level to the size of the first one.
        # NOTE(review): ``align_corners`` is hard-coded to True here and
        # does not follow ``self.align_corners`` -- confirm intended.
        fusion_list = [fpn_feature_list[0]]
        fusion_list.extend(
            nn.functional.interpolate(
                feature, output_size, mode='bilinear', align_corners=True)
            for feature in fpn_feature_list[1:])

        fusion_out = torch.cat(fusion_list, 1)
        x = self.conv_last(fusion_out)
        output = self.cls_seg(x)

        return output


class AlignedModule(BaseModule):
    """The implementation of Flow Alignment Module (FAM).

    Args:
        inplane (int): The number of FAM input channels.
        outplane (int): The number of channels both input features are
            reduced to before the flow field is predicted. The warped
            output itself keeps ``inplane`` channels.
        kernel_size (int): Kernel size of the conv that predicts the flow
            field. Must be odd so that the flow field keeps the spatial
            size of the low-level feature. Default: 3.
    """

    def __init__(self, inplane, outplane, kernel_size=3):
        super(AlignedModule, self).__init__()
        self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)
        self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)
        # "Same" padding derived from the kernel size: the predicted flow
        # must have exactly the spatial size of the sampling grid built in
        # ``flow_warp``. A fixed padding of 1 is only right for a 3x3 kernel.
        self.flow_make = nn.Conv2d(
            outplane * 2,
            2,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False)

    def forward(self, x):
        """Warp the high-level feature onto the low-level feature's grid.

        Args:
            x (list[Tensor] | tuple[Tensor]): ``[low_feature, h_feature]``.
                ``low_feature`` defines the output spatial size and
                ``h_feature`` is the feature that gets warped.

        Returns:
            Tensor: ``h_feature`` resampled to the height and width of
                ``low_feature``.
        """
        low_feature, h_feature = x
        h_feature_orign = h_feature
        size = tuple(low_feature.size()[2:])

        # Reduce both features, bring the high-level one to the target
        # size, and predict a 2-channel flow field from their concatenation.
        low_feature = self.down_l(low_feature)
        h_feature = resize(
            self.down_h(h_feature),
            size=size,
            mode='bilinear',
            align_corners=True)
        flow = self.flow_make(torch.cat([h_feature, low_feature], 1))

        # The warp is applied to the original, unreduced high-level feature.
        return self.flow_warp(h_feature_orign, flow, size=size)

    def flow_warp(self, input, flow, size):
        """Implementation of Warp Procedure in Fig 3(b) of original paper,
        which is between the Flow Field and the feature map to be warped.

        Args:
            input (Tensor): Feature map to be warped, shape (N, C, H, W).
                ``forward`` passes the unreduced high-level feature here,
                whose resolution may be lower than ``size``.
            flow (Tensor): Semantic Flow Field that will give dynamic
                indication about how to align these two feature maps
                effectively, shape (N, 2, out_h, out_w).
            size (Tuple): Shape of height and width of output.

        Returns:
            Tensor: The warped feature map of shape (N, C, out_h, out_w).

        For example, in cityscapes 1024x2048 dataset with ResNet18 config,
        feature map from backbone is:
            [[1, 64, 256, 512],
            [1, 128, 128, 256],
            [1, 256, 64, 128],
            [1, 512, 32, 64]]

        Thus, its inverse shape of [input, flow, size] is:
            [[1, 128, 32, 64], [1, 2, 64, 128], (64, 128)],
            [[1, 128, 64, 128], [1, 2, 128, 256], (128, 256)], and
            [[1, 128, 128, 256], [1, 2, 256, 512], (256, 512)], respectively.

        The final output is:
            [[1, 128, 64, 128],
            [1, 128, 128, 256],
            [1, 128, 256, 512]], respectively.
        """
        out_h, out_w = size

        # Scale applied to the flow offsets; the last axis is ordered
        # (x, y) to match the sampling grid below. ``.to(input)`` matches
        # both the dtype and the device of ``input`` in one step.
        norm = torch.tensor([[[[out_w, out_h]]]]).to(input)

        # Base sampling grid in normalized coordinates, from -1 to 1.
        ys = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).expand(out_h, out_w)
        xs = torch.linspace(-1.0, 1.0, out_w).view(1, -1).expand(out_h, out_w)
        grid = torch.stack((xs, ys), dim=2).unsqueeze(0).to(input)

        # Warped grid corrected by the flow offset. The leading dimension
        # of size 1 broadcasts against the batch dimension of ``flow``.
        grid = grid + flow.permute(0, 2, 3, 1) / norm

        # Sampling mechanism interpolates the values of the 4-neighbors
        # (top-left, top-right, bottom-left, and bottom-right) of input.
        output = nn.functional.grid_sample(input, grid, align_corners=True)
        return output
95 changes: 95 additions & 0 deletions tests/test_models/test_heads/test_sfnet_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
import torch

from mmseg.models.decode_heads import SFNetHead
from .utils import _conv_has_norm, to_cuda


def test_sfnet_head():
    """Check SFNetHead with the ResNet50 and the ResNet18 settings."""
    # Each entry: (keyword arguments of the head, channels of the four
    # feature maps fed to it).
    settings = [
        # ResNet50 backbone: relies on the default fpn_inplanes / fpn_dim.
        (dict(in_channels=2048, channels=256, num_classes=19),
         (256, 512, 1024, 2048)),
        # ResNet18 backbone: narrower pyramid, passed explicitly.
        (dict(
            in_channels=512,
            channels=128,
            num_classes=19,
            fpn_inplanes=[64, 128, 256, 512],
            fpn_dim=128), (64, 128, 256, 512)),
    ]

    for head_kwargs, feature_channels in settings:
        # pool_scales must be list|tuple
        with pytest.raises(AssertionError):
            SFNetHead(pool_scales=1, **head_kwargs)

        # test no norm_cfg
        head = SFNetHead(**head_kwargs)
        assert not _conv_has_norm(head, sync_bn=False)

        # test with norm_cfg
        head = SFNetHead(norm_cfg=dict(type='SyncBN'), **head_kwargs)
        assert _conv_has_norm(head, sync_bn=True)

        # test the forward pass on same-sized multi-level inputs
        inputs = [
            torch.randn(1, num_channels, 45, 45)
            for num_channels in feature_channels
        ]
        head = SFNetHead(pool_scales=(1, 2, 3, 6), **head_kwargs)
        if torch.cuda.is_available():
            head, inputs = to_cuda(head, inputs)
        for branch, expected_scale in enumerate((1, 2, 3)):
            assert head.psp_modules[branch][0].output_size == expected_scale
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 45, 45)