Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Rethinking BiSeNet (STDCSeg) to paddleseg #1305

Merged
merged 15 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/stdcseg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Rethinking BiSeNet For Real-time Semantic Segmentation

## Reference

> Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
CuberrChen marked this conversation as resolved.
Show resolved Hide resolved


## Performance

### CityScapes

| Model |Resolution | Training Iters | mIOU | Links1 |Links2 |log|
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

表格删去link1,link2,直接在最后一列links加上 model|log|vdl 即可

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model 中包含训练好模型的下载链接,log为log下载链接,backbone可以不用加入到这部分

| -------------|---------|---------------|------ |------ |------ |-------|
| STDC2-Seg50 (Paddle)| 1024x512|80k| 74.62 |[backbone提取码:tss7](https://pan.baidu.com/s/16kh3aHTBBX6wfKiIG-y3yA) |[model+log提取码:nchx](https://pan.baidu.com/s/1sFHqZWhcl8hFzGCrXu_c7Q)|[vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=30a6031fcc7cc09db93b4d33eb21724a)|
62 changes: 62 additions & 0 deletions configs/stdcseg/stdc2_seg_cityscapes_1024x512_80k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
_base_: '../_base_/cityscapes.yml'

batch_size: 36
iters: 80000

model:
type: STDCSeg
backbone:
type: STDCNet1446
pretrained: '/home/path/STDCNet1446_76.47.pdiparams'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backbone不应当是本地索引,应当是一个网址

num_classes: 19
pretrained: null


CuberrChen marked this conversation as resolved.
Show resolved Hide resolved
train_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.125
max_scale_factor: 1.5
scale_step_size: 0.125
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: RandomDistort
brightness_range: 0.5
contrast_range: 0.5
saturation_range: 0.5
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mode: train

val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mode: val


optimizer:
type: sgd
momentum: 0.9
weight_decay: 4.0e-5

loss:
types:
- type: OhemCrossEntropyLoss
- type: OhemCrossEntropyLoss
- type: OhemCrossEntropyLoss
- type: DetailAggregateLoss
coef: [1, 1, 1, 1]

lr_scheduler:
type: PolynomialDecay
learning_rate: 0.01
end_lr: 0
power: 0.9
1 change: 1 addition & 0 deletions paddleseg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@
from .ppseg_lite import *
from .mla_transformer import MLATransformer
from .portraitnet import PortraitNet
from .stdcseg import STDCSeg
from .segformer import SegFormer
1 change: 1 addition & 0 deletions paddleseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .swin_transformer import *
from .mobilenetv2 import *
from .mix_transformer import *
from .stdcnet import *
247 changes: 247 additions & 0 deletions paddleseg/models/backbones/stdcnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import math
from paddleseg.cvlibs import manager
from paddleseg.utils import utils

__all__ = [
"STDCNet813", "STDCNet1446",
]

class STDCNet(nn.Layer):
"""
The STDCNet implementation based on PaddlePaddle.

The original article refers to Meituan
Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation."
(https://arxiv.org/abs/2104.13188)

Args:
base (int, optional): base channels . Default: 64.
layers (list, optional): layers numbers list. It determines STDC block numbers of STDCNet's stage3\4\5. Defualt:[4, 5, 3].
block_num(int,optional): block_num of features block. Default: 4.
type(str,optional): feature fusion method "cat"/"add".Default:"cat".
num_classes (int, optional): class number for image classification. Default: 1000.
dropout(float,optional): dropout ratio. if >0,use dropout ratio. Default: 0.20.
pretrained (str, optional): The path of pretrained model.

"""

def __init__(self, base=64,
layers=[4, 5, 3],
block_num=4,
type="cat",
num_classes=1000,
dropout=0.20,
pretrained=None,
use_conv_last=False):
super(STDCNet, self).__init__()
if type == "cat":
block = CatBottleneck
elif type == "add":
block = AddBottleneck
self.use_conv_last = use_conv_last
self.features = self._make_layers(base, layers, block_num, block)
self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)
self.gap = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16),bias_attr=None)
self.bn = nn.BatchNorm1D(max(1024, base * 16))
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=dropout)
self.linear = nn.Linear(max(1024, base * 16), num_classes, bias_attr=None)

if(layers==[4,5,3]): #stdc1446
self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:6])
self.x16 = nn.Sequential(self.features[6:11])
self.x32 = nn.Sequential(self.features[11:])
elif(layers==[2,2,2]):#stdc813
self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:4])
self.x16 = nn.Sequential(self.features[4:6])
self.x32 = nn.Sequential(self.features[6:])
else:
raise NotImplementedError("model with layers:{} is not implemented!".format(layers))

self.pretrained = pretrained
self.init_weight()

def init_weight(self):
if self.pretrained is not None:
utils.load_pretrained_model(self, self.pretrained)

def _make_layers(self, base, layers, block_num, block):
features = []
features += [ConvX(3, base // 2, 3, 2)]
features += [ConvX(base // 2, base, 3, 2)]

for i, layer in enumerate(layers):
for j in range(layer):
if i == 0 and j == 0:
features.append(block(base, base * 4, block_num, 2))
elif j == 0:
features.append(block(base * int(math.pow(2, i + 1)), base * int(math.pow(2, i + 2)), block_num, 2))
else:
features.append(block(base * int(math.pow(2, i + 2)), base * int(math.pow(2, i + 2)), block_num, 1))

return nn.Sequential(*features)

def forward(self, x):
feat2 = self.x2(x)
feat4 = self.x4(feat2)
feat8 = self.x8(feat4)
feat16 = self.x16(feat8)
feat32 = self.x32(feat16)
if self.use_conv_last:
feat32 = self.conv_last(feat32)

return feat2, feat4, feat8, feat16, feat32

def forward_impl(self, x):
out = self.features(x)
out = self.conv_last(out).pow(2)
out = self.gap(out).flatten(1)
out = self.fc(out)
# out = self.bn(out)
out = self.relu(out)
# out = self.relu(self.bn(self.fc(out)))
out = self.dropout(out)
out = self.linear(out)
return out

class ConvX(nn.Layer):
CuberrChen marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, in_planes, out_planes, kernel=3, stride=1):
super(ConvX, self).__init__()
self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel // 2,bias_attr=None)
self.bn = nn.BatchNorm2D(out_planes)
self.relu = nn.ReLU()

def forward(self, x):
out = self.relu(self.bn(self.conv(x)))
return out


class AddBottleneck(nn.Layer):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super(AddBottleneck, self).__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.LayerList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2D(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,bias_attr=None),
nn.BatchNorm2D(out_planes // 2),
)
self.skip = nn.Sequential(
nn.Conv2D(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes,bias_attr=None),
nn.BatchNorm2D(in_planes),
nn.Conv2D(in_planes, out_planes, kernel_size=1,bias_attr=None),
nn.BatchNorm2D(out_planes),
)
stride = 1

for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
elif idx == 1 and block_num > 2:
self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
elif idx < block_num - 1:
self.conv_list.append(
ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))
else:
self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))

def forward(self, x):
out_list = []
out = x

for idx, conv in enumerate(self.conv_list):
if idx == 0 and self.stride == 2:
out = self.avd_layer(conv(out))
else:
out = conv(out)
out_list.append(out)

if self.stride == 2:
x = self.skip(x)

return paddle.concat(out_list, axis=1) + x


class CatBottleneck(nn.Layer):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super(CatBottleneck, self).__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.LayerList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2D(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,bias_attr=None
),
nn.BatchNorm2D(out_planes // 2),
)
self.skip = nn.AvgPool2D(kernel_size=3, stride=2, padding=1)
stride = 1

for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
elif idx == 1 and block_num > 2:
self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
elif idx < block_num - 1:
self.conv_list.append(
ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))
else:
self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))

def forward(self, x):
out_list = []
out1 = self.conv_list[0](x)

for idx, conv in enumerate(self.conv_list[1:]):
if idx == 0:
if self.stride == 2:
out = conv(self.avd_layer(out1))
else:
out = conv(out1)
else:
out = conv(out)
out_list.append(out)

if self.stride == 2:
out1 = self.skip(out1)
out_list.insert(0, out1)

out = paddle.concat(out_list, axis=1)
return out


@manager.BACKBONES.add_component
def STDCNet1446(**kwargs):
model = STDCNet(base=64,layers=[4,5,3],**kwargs)
return model

@manager.BACKBONES.add_component
def STDCNet813(**kwargs):
model = STDCNet(base=64,layers=[2,2,2],**kwargs)
return model
1 change: 1 addition & 0 deletions paddleseg/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@
from .focal_loss import FocalLoss
from .kl_loss import KLLoss
from .rmi_loss import RMILoss
from .detail_aggregate_loss import DetailAggregateLoss
Loading