diff --git a/configs/mobileseg/README.md b/configs/mobileseg/README.md
new file mode 100644
index 0000000000..7d2a8ab5f3
--- /dev/null
+++ b/configs/mobileseg/README.md
@@ -0,0 +1,29 @@
+# MobileSeg
+
+These semantic segmentation models are designed for mobile and edge devices.
+
+MobileSeg models adopt an encoder-decoder architecture and use lightweight networks as the encoder.
+
+## Reference
+
+> Sandler, Mark, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, and Liang-Chieh Chen. "MobileNetV2: Inverted residuals and linear bottlenecks." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510-4520. 2018.
+
+> Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV3." In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1314-1324. 2019.
+
+> Ma, Ningning, Xiangyu Zhang, Hai-Tao Zheng, and Jian Sun. "ShuffleNet V2: Practical guidelines for efficient CNN architecture design." In Proceedings of the European Conference on Computer Vision (ECCV), pp. 116-131. 2018.
+
+> Yu, Changqian, Bin Xiao, Changxin Gao, Lu Yuan, Lei Zhang, Nong Sang, and Jingdong Wang. "Lite-HRNet: A lightweight high-resolution network." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10440-10450. 2021.
+
+> Han, Kai, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, and Chang Xu. "GhostNet: More features from cheap operations." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1580-1589. 2020.
+
+## Performance
+
+### Cityscapes
+
+| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|MobileSeg|MobileNetV2|1024x512|80000|73.94%|74.32%|75.33%|[model](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_mobilenetv2_cityscapes_1024x512_80k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_mobilenetv2_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=f210c79b6fd52f5135cf2f238e9d678d)|
+|MobileSeg|MobileNetV3_large_x1_0|1024x512|80000|73.47%|73.72%|74.70%|[model](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_mobilenetv3_cityscapes_1024x512_80k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_mobilenetv3_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=28c57d0e666337ea98a1046160ef95d2)|
+|MobileSeg|Lite_HRNet_18|1024x512|80000|70.75%|71.62%|72.40%|[model](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_litehrnet18_cityscapes_1024x512_80k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_litehrnet18_cityscapes_1024x512_80k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=02706145c7c463f3c76a0cb9d54728b8)|
+|MobileSeg|ShuffleNetV2_x1_0|1024x512|80000|69.46%|70.00%|70.90%|[model](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_shufflenetv2_cityscapes_1024x512_80k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_shufflenetv2_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=3d83c00cf9b90f2446959e8c97a4fb7a)|
+|MobileSeg|GhostNet_x1_0|1024x512|80000|71.88%|72.22%|73.11%|[model](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_ghostnet_cityscapes_1024x512_80k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/cityscapes/mobileseg_ghostnet_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=73a6b325c0ae941a40746d53911c03bc)| diff --git a/configs/mobileseg/mobileseg_ghostnet_cityscapes_1024x512_80k.yml b/configs/mobileseg/mobileseg_ghostnet_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..d1c6b52065 --- /dev/null +++ b/configs/mobileseg/mobileseg_ghostnet_cityscapes_1024x512_80k.yml @@ -0,0 +1,48 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 4 # use 4 GPU in default +iters: 80000 + +optimizer: + weight_decay: 5.0e-4 + +lr_scheduler: + warmup_iters: 1000 + warmup_start_lr: 1.0e-5 + learning_rate: 0.005 + +loss: + types: + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + coef: [1, 1, 1] + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mode: train + +model: + type: MobileSeg + backbone: + type: GhostNet_x1_0 # out channels: [24, 40, 112, 160] + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/ghostnet_x1_0.zip + cm_bin_sizes: [1, 2, 4] + cm_out_ch: 128 + arm_out_chs: [32, 64, 128] + seg_head_inter_chs: [32, 32, 32] \ No newline at end of file diff --git a/configs/mobileseg/mobileseg_litehrnet18_cityscapes_1024x512_80k.yml b/configs/mobileseg/mobileseg_litehrnet18_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..5f4769ab64 --- /dev/null +++ b/configs/mobileseg/mobileseg_litehrnet18_cityscapes_1024x512_80k.yml @@ -0,0 +1,50 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 4 +iters: 80000 + +optimizer: + weight_decay: 5.0e-4 + +lr_scheduler: + warmup_iters: 1000 + warmup_start_lr: 1.0e-5 + learning_rate: 0.005 + +loss: + types: + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + coef: [1, 1, 1] + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mode: train + +model: + type: MobileSeg + backbone: + type: Lite_HRNet_18 + use_head: True # False : [40, 80, 160, 320] True: [40, 40, 80, 160] + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/lite_hrnet_18.tar.gz + backbone_indices: [0, 1, 2] + cm_bin_sizes: [1, 2, 4] + cm_out_ch: 128 + arm_out_chs: [32, 64, 128] + seg_head_inter_chs: [32, 32, 32] \ No newline at end of file diff --git a/configs/mobileseg/mobileseg_mobilenetv2_cityscapes_1024x512_80k.yml b/configs/mobileseg/mobileseg_mobilenetv2_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..a109034a32 --- /dev/null +++ b/configs/mobileseg/mobileseg_mobilenetv2_cityscapes_1024x512_80k.yml @@ -0,0 +1,48 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 4 +iters: 80000 + +optimizer: + 
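# ---------------------------------------------------------------------------
# Usage note (illustrative, not part of the config files in this diff): the
# configs in configs/mobileseg/ are meant to be passed to PaddleSeg's standard
# entry points. A minimal sketch, assuming the usual train.py / val.py scripts
# at the repository root, a prepared Cityscapes dataset, and an arbitrary
# output directory; exact flags can differ between PaddleSeg releases:
#
#   # train (prefix with `python -m paddle.distributed.launch` for multi-GPU)
#   python train.py \
#     --config configs/mobileseg/mobileseg_mobilenetv2_cityscapes_1024x512_80k.yml \
#     --do_eval --use_vdl --save_dir output/mobileseg_mobilenetv2
#
#   # evaluate a trained checkpoint
#   python val.py \
#     --config configs/mobileseg/mobileseg_mobilenetv2_cityscapes_1024x512_80k.yml \
#     --model_path output/mobileseg_mobilenetv2/best_model/model.pdparams
# ---------------------------------------------------------------------------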
weight_decay: 5.0e-4 + +lr_scheduler: + warmup_iters: 1000 + warmup_start_lr: 1.0e-5 + learning_rate: 0.005 + +loss: + types: + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + coef: [1, 1, 1] + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mode: train + +model: + type: MobileSeg + backbone: + type: MobileNetV2_x1_0 # out channels: [24, 32, 96, 320] + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/mobilenetv2_x1_0_ssld.tar.gz + cm_bin_sizes: [1, 2, 4] + cm_out_ch: 128 + arm_out_chs: [32, 64, 128] + seg_head_inter_chs: [32, 32, 32] diff --git a/configs/mobileseg/mobileseg_mobilenetv3_cityscapes_1024x512_80k.yml b/configs/mobileseg/mobileseg_mobilenetv3_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..fd4e7a7cc5 --- /dev/null +++ b/configs/mobileseg/mobileseg_mobilenetv3_cityscapes_1024x512_80k.yml @@ -0,0 +1,48 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 4 +iters: 80000 + +optimizer: + weight_decay: 5.0e-4 + +lr_scheduler: + warmup_iters: 1000 + warmup_start_lr: 1.0e-5 + learning_rate: 0.005 + +loss: + types: + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + coef: [1, 1, 1] + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mode: train + +model: + type: MobileSeg + backbone: + type: MobileNetV3_large_x1_0 # out channels: [24, 40, 112, 160] + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/mobilenetv3_large_x1_0_ssld.tar.gz + cm_bin_sizes: [1, 2, 4] + cm_out_ch: 128 + arm_out_chs: [32, 64, 128] + seg_head_inter_chs: [32, 32, 32] \ No newline at end of file diff --git a/configs/mobileseg/mobileseg_shufflenetv2_cityscapes_1024x512_80k.yml b/configs/mobileseg/mobileseg_shufflenetv2_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..ec766944fd --- /dev/null +++ b/configs/mobileseg/mobileseg_shufflenetv2_cityscapes_1024x512_80k.yml @@ -0,0 +1,48 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 4 +iters: 80000 + +optimizer: + weight_decay: 5.0e-4 + +lr_scheduler: + warmup_iters: 1000 + warmup_start_lr: 1.0e-5 + learning_rate: 0.005 + +loss: + types: + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + - type: OhemCrossEntropyLoss + min_kept: 130000 + coef: [1, 1, 1] + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mode: train + +model: + type: MobileSeg + backbone: + type: ShuffleNetV2_x1_0 # out channels: [24, 116, 232, 464] + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/shufflenetv2_x1_0.zip + cm_bin_sizes: [1, 2, 4] + 
cm_out_ch: 128 + arm_out_chs: [32, 64, 128] + seg_head_inter_chs: [32, 32, 32] \ No newline at end of file diff --git a/paddleseg/models/__init__.py b/paddleseg/models/__init__.py index 734350fb88..d7ea63585c 100644 --- a/paddleseg/models/__init__.py +++ b/paddleseg/models/__init__.py @@ -59,3 +59,4 @@ from .glore import GloRe from .ddrnet import DDRNet_23 from .ccnet import CCNet +from .mobileseg import MobileSeg diff --git a/paddleseg/models/backbones/__init__.py b/paddleseg/models/backbones/__init__.py index 108f87d013..77860d0626 100644 --- a/paddleseg/models/backbones/__init__.py +++ b/paddleseg/models/backbones/__init__.py @@ -21,3 +21,6 @@ from .mobilenetv2 import * from .mix_transformer import * from .stdcnet import * +from .lite_hrnet import * +from .shufflenetv2 import * +from .ghostnet import * \ No newline at end of file diff --git a/paddleseg/models/backbones/ghostnet.py b/paddleseg/models/backbones/ghostnet.py new file mode 100644 index 0000000000..eaa47f2880 --- /dev/null +++ b/paddleseg/models/backbones/ghostnet.py @@ -0,0 +1,318 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, AdaptiveAvgPool2D, Linear +from paddle.regularizer import L2Decay +from paddle.nn.initializer import Uniform, KaimingNormal + +from paddleseg.cvlibs import manager +from paddleseg.utils import utils, logger + +__all__ = ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3"] + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + act="relu", + name=None): + super(ConvBNLayer, self).__init__() + self._conv = Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr( + initializer=KaimingNormal(), name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + + self._batch_norm = BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr( + name=bn_name + "_scale", regularizer=L2Decay(0.0)), + bias_attr=ParamAttr( + name=bn_name + "_offset", regularizer=L2Decay(0.0)), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class SEBlock(nn.Layer): + def __init__(self, num_channels, reduction_ratio=4, name=None): + super(SEBlock, self).__init__() + self.pool2d_gap = AdaptiveAvgPool2D(1) + self._num_channels = num_channels + stdv = 1.0 / math.sqrt(num_channels * 1.0) + med_ch = num_channels // reduction_ratio + self.squeeze = Linear( + num_channels, + med_ch, + weight_attr=ParamAttr( + initializer=Uniform(-stdv, stdv), name=name + "_1_weights"), + 
bias_attr=ParamAttr(name=name + "_1_offset")) + stdv = 1.0 / math.sqrt(med_ch * 1.0) + self.excitation = Linear( + med_ch, + num_channels, + weight_attr=ParamAttr( + initializer=Uniform(-stdv, stdv), name=name + "_2_weights"), + bias_attr=ParamAttr(name=name + "_2_offset")) + + def forward(self, inputs): + pool = self.pool2d_gap(inputs) + pool = paddle.squeeze(pool, axis=[2, 3]) + squeeze = self.squeeze(pool) + squeeze = F.relu(squeeze) + excitation = self.excitation(squeeze) + excitation = paddle.clip(x=excitation, min=0, max=1) + excitation = paddle.unsqueeze(excitation, axis=[2, 3]) + out = paddle.multiply(inputs, excitation) + return out + + +class GhostModule(nn.Layer): + def __init__(self, + in_channels, + output_channels, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + relu=True, + name=None): + super(GhostModule, self).__init__() + init_channels = int(math.ceil(output_channels / ratio)) + new_channels = int(init_channels * (ratio - 1)) + self.primary_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=init_channels, + kernel_size=kernel_size, + stride=stride, + groups=1, + act="relu" if relu else None, + name=name + "_primary_conv") + self.cheap_operation = ConvBNLayer( + in_channels=init_channels, + out_channels=new_channels, + kernel_size=dw_size, + stride=1, + groups=init_channels, + act="relu" if relu else None, + name=name + "_cheap_operation") + + def forward(self, inputs): + x = self.primary_conv(inputs) + y = self.cheap_operation(x) + out = paddle.concat([x, y], axis=1) + return out + + +class GhostBottleneck(nn.Layer): + def __init__(self, + in_channels, + hidden_dim, + output_channels, + kernel_size, + stride, + use_se, + name=None): + super(GhostBottleneck, self).__init__() + self._stride = stride + self._use_se = use_se + self._num_channels = in_channels + self._output_channels = output_channels + self.ghost_module_1 = GhostModule( + in_channels=in_channels, + output_channels=hidden_dim, + kernel_size=1, + stride=1, + relu=True, + name=name + "_ghost_module_1") + if stride == 2: + self.depthwise_conv = ConvBNLayer( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=kernel_size, + stride=stride, + groups=hidden_dim, + act=None, + name=name + + "_depthwise_depthwise" # looks strange due to an old typo, will be fixed later. + ) + if use_se: + self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se") + self.ghost_module_2 = GhostModule( + in_channels=hidden_dim, + output_channels=output_channels, + kernel_size=1, + relu=False, + name=name + "_ghost_module_2") + if stride != 1 or in_channels != output_channels: + self.shortcut_depthwise = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + groups=in_channels, + act=None, + name=name + + "_shortcut_depthwise_depthwise" # looks strange due to an old typo, will be fixed later. 
+ ) + self.shortcut_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=output_channels, + kernel_size=1, + stride=1, + groups=1, + act=None, + name=name + "_shortcut_conv") + + def forward(self, inputs): + x = self.ghost_module_1(inputs) + if self._stride == 2: + x = self.depthwise_conv(x) + if self._use_se: + x = self.se_block(x) + x = self.ghost_module_2(x) + if self._stride == 1 and self._num_channels == self._output_channels: + shortcut = inputs + else: + shortcut = self.shortcut_depthwise(inputs) + shortcut = self.shortcut_conv(shortcut) + return paddle.add(x=x, y=shortcut) + + +class GhostNet(nn.Layer): + def __init__(self, scale, pretrained=None): + super(GhostNet, self).__init__() + self.cfgs = [ + # k, t, c, SE, s + [3, 16, 16, 0, 1], + [3, 48, 24, 0, 2], + [3, 72, 24, 0, 1], # x4 + [5, 72, 40, 1, 2], + [5, 120, 40, 1, 1], # x8 + [3, 240, 80, 0, 2], + [3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 1, 1], + [3, 672, 112, 1, 1], # x16 + [5, 672, 160, 1, 2], + [5, 960, 160, 0, 1], + [5, 960, 160, 1, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 1, 1] # x32 + ] + self.scale = scale + self.pretrained = pretrained + + output_channels = int(self._make_divisible(16 * self.scale, 4)) + self.conv1 = ConvBNLayer( + in_channels=3, + out_channels=output_channels, + kernel_size=3, + stride=2, + groups=1, + act="relu", + name="conv1") + + # build inverted residual blocks + self.out_index = [2, 4, 10, 15] + self.feat_channels = [] + self.ghost_bottleneck_list = [] + for idx, (k, exp_size, c, use_se, s) in enumerate(self.cfgs): + in_channels = output_channels + output_channels = int(self._make_divisible(c * self.scale, 4)) + hidden_dim = int(self._make_divisible(exp_size * self.scale, 4)) + ghost_bottleneck = self.add_sublayer( + name="_ghostbottleneck_" + str(idx), + sublayer=GhostBottleneck( + in_channels=in_channels, + hidden_dim=hidden_dim, + output_channels=output_channels, + kernel_size=k, + stride=s, + use_se=use_se, + name="_ghostbottleneck_" + str(idx))) + self.ghost_bottleneck_list.append(ghost_bottleneck) + if idx in self.out_index: + self.feat_channels.append(output_channels) + + self.init_weight() + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + def forward(self, inputs): + feat_list = [] + x = self.conv1(inputs) + for idx, ghost_bottleneck in enumerate(self.ghost_bottleneck_list): + x = ghost_bottleneck(x) + if idx in self.out_index: + feat_list.append(x) + return feat_list + + def _make_divisible(self, v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +@manager.BACKBONES.add_component +def GhostNet_x0_5(**kwargs): + model = GhostNet(scale=0.5, **kwargs) + return model + + +@manager.BACKBONES.add_component +def GhostNet_x1_0(**kwargs): + model = GhostNet(scale=1.0, **kwargs) + return model + + +@manager.BACKBONES.add_component +def GhostNet_x1_3(**kwargs): + model = GhostNet(scale=1.3, **kwargs) + return model diff --git a/paddleseg/models/backbones/lite_hrnet.py b/paddleseg/models/backbones/lite_hrnet.py new file mode 100644 index 0000000000..9b55e68ca5 --- /dev/null +++ b/paddleseg/models/backbones/lite_hrnet.py @@ -0,0 +1,972 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is based on +https://github.com/HRNet/Lite-HRNet/blob/hrnet/models/backbones/litehrnet.py +""" + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from numbers import Integral +from paddle import ParamAttr +from paddle.regularizer import L2Decay +from paddle.nn.initializer import Normal, Constant + +from paddleseg.cvlibs import manager +from paddleseg import utils + +__all__ = [ + "Lite_HRNet_18", "Lite_HRNet_30", "Lite_HRNet_naive", + "Lite_HRNet_wider_naive", "LiteHRNet" +] + + +def Conv2d(in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + weight_init=Normal(std=0.001), + bias_init=Constant(0.)): + weight_attr = paddle.framework.ParamAttr(initializer=weight_init) + if bias: + bias_attr = paddle.framework.ParamAttr(initializer=bias_init) + else: + bias_attr = False + conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + weight_attr=weight_attr, + bias_attr=bias_attr) + return conv + + +def channel_shuffle(x, groups): + x_shape = paddle.shape(x) + batch_size, height, width = x_shape[0], x_shape[2], x_shape[3] + num_channels = x.shape[1] + channels_per_group = num_channels // groups + + x = paddle.reshape( + x=x, shape=[batch_size, groups, channels_per_group, height, width]) + x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4]) + x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width]) + + return x + + +class ConvNormLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size, + stride=1, + groups=1, + norm_type=None, + norm_groups=32, + norm_decay=0., + freeze_norm=False, + act=None): + super(ConvNormLayer, self).__init__() + self.act = act + norm_lr = 0. if freeze_norm else 1. 
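# ---------------------------------------------------------------------------
# Note on channel_shuffle() defined above (illustrative comment, not part of
# the original file): for a tensor with 6 channels and groups=2, the
# reshape -> transpose -> reshape sequence reorders the channel indices
# [0, 1, 2, 3, 4, 5] into [0, 3, 1, 4, 2, 5], i.e. it interleaves the two
# channel groups so that information mixes across branches after the
# grouped/split convolutions in the shuffle-style blocks used by Lite-HRNet.
# ---------------------------------------------------------------------------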
+ if norm_type is not None: + assert norm_type in ['bn', 'sync_bn', 'gn'], \ + "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type) + param_attr = ParamAttr( + initializer=Constant(1.0), + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay), ) + bias_attr = ParamAttr( + learning_rate=norm_lr, regularizer=L2Decay(norm_decay)) + global_stats = True if freeze_norm else None + if norm_type in ['bn', 'sync_bn']: + self.norm = nn.BatchNorm2D( + ch_out, + weight_attr=param_attr, + bias_attr=bias_attr, + use_global_stats=global_stats, ) + elif norm_type == 'gn': + self.norm = nn.GroupNorm( + num_groups=norm_groups, + num_channels=ch_out, + weight_attr=param_attr, + bias_attr=bias_attr) + norm_params = self.norm.parameters() + if freeze_norm: + for param in norm_params: + param.stop_gradient = True + conv_bias_attr = False + else: + conv_bias_attr = True + self.norm = None + + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.001)), + bias_attr=conv_bias_attr) + + def forward(self, inputs): + out = self.conv(inputs) + if self.norm is not None: + out = self.norm(out) + + if self.act == 'relu': + out = F.relu(out) + elif self.act == 'sigmoid': + out = F.sigmoid(out) + return out + + +class DepthWiseSeparableConvNormLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size, + stride=1, + dw_norm_type=None, + pw_norm_type=None, + norm_decay=0., + freeze_norm=False, + dw_act=None, + pw_act=None): + super(DepthWiseSeparableConvNormLayer, self).__init__() + self.depthwise_conv = ConvNormLayer( + ch_in=ch_in, + ch_out=ch_in, + filter_size=filter_size, + stride=stride, + groups=ch_in, + norm_type=dw_norm_type, + act=dw_act, + norm_decay=norm_decay, + freeze_norm=freeze_norm, ) + self.pointwise_conv = ConvNormLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=1, + stride=1, + norm_type=pw_norm_type, + act=pw_act, + norm_decay=norm_decay, + freeze_norm=freeze_norm, ) + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x + + +class CrossResolutionWeightingModule(nn.Layer): + def __init__(self, + channels, + ratio=16, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(CrossResolutionWeightingModule, self).__init__() + self.channels = channels + total_channel = sum(channels) + self.conv1 = ConvNormLayer( + ch_in=total_channel, + ch_out=total_channel // ratio, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + self.conv2 = ConvNormLayer( + ch_in=total_channel // ratio, + ch_out=total_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='sigmoid', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + + def forward(self, x): + out = [] + for idx, xi in enumerate(x[:-1]): + kernel_size = stride = pow(2, len(x) - idx - 1) + xi = F.avg_pool2d(xi, kernel_size=kernel_size, stride=stride) + out.append(xi) + out.append(x[-1]) + + out = paddle.concat(out, 1) + out = self.conv1(out) + out = self.conv2(out) + out = paddle.split(out, self.channels, 1) + out = [ + s * F.interpolate( + a, paddle.shape(s)[-2:], mode='nearest') for s, a in zip(x, out) + ] + return out + + +class SpatialWeightingModule(nn.Layer): + def __init__(self, in_channel, ratio=16, freeze_norm=False, norm_decay=0.): + super(SpatialWeightingModule, self).__init__() + self.global_avgpooling = 
nn.AdaptiveAvgPool2D(1) + self.conv1 = ConvNormLayer( + ch_in=in_channel, + ch_out=in_channel // ratio, + filter_size=1, + stride=1, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + self.conv2 = ConvNormLayer( + ch_in=in_channel // ratio, + ch_out=in_channel, + filter_size=1, + stride=1, + act='sigmoid', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + + def forward(self, x): + out = self.global_avgpooling(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out + + +class ConditionalChannelWeightingBlock(nn.Layer): + def __init__(self, + in_channels, + stride, + reduce_ratio, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(ConditionalChannelWeightingBlock, self).__init__() + assert stride in [1, 2] + branch_channels = [channel // 2 for channel in in_channels] + + self.cross_resolution_weighting = CrossResolutionWeightingModule( + branch_channels, + ratio=reduce_ratio, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay) + self.depthwise_convs = nn.LayerList([ + ConvNormLayer( + channel, + channel, + filter_size=3, + stride=stride, + groups=channel, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay) for channel in branch_channels + ]) + + self.spatial_weighting = nn.LayerList([ + SpatialWeightingModule( + channel, + ratio=4, + freeze_norm=freeze_norm, + norm_decay=norm_decay) for channel in branch_channels + ]) + + def forward(self, x): + x = [s.chunk(2, axis=1) for s in x] + x1 = [s[0] for s in x] + x2 = [s[1] for s in x] + + x2 = self.cross_resolution_weighting(x2) + x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)] + x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)] + + out = [paddle.concat([s1, s2], axis=1) for s1, s2 in zip(x1, x2)] + out = [channel_shuffle(s, groups=2) for s in out] + return out + + +class ShuffleUnit(nn.Layer): + def __init__(self, + in_channel, + out_channel, + stride, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(ShuffleUnit, self).__init__() + branch_channel = out_channel // 2 + self.stride = stride + if self.stride == 1: + assert in_channel == branch_channel * 2, \ + "when stride=1, in_channel {} should equal to branch_channel*2 {}".format(in_channel, branch_channel * 2) + if stride > 1: + self.branch1 = nn.Sequential( + ConvNormLayer( + ch_in=in_channel, + ch_out=in_channel, + filter_size=3, + stride=self.stride, + groups=in_channel, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay), + ConvNormLayer( + ch_in=in_channel, + ch_out=branch_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay), ) + self.branch2 = nn.Sequential( + ConvNormLayer( + ch_in=branch_channel if stride == 1 else in_channel, + ch_out=branch_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay), + ConvNormLayer( + ch_in=branch_channel, + ch_out=branch_channel, + filter_size=3, + stride=self.stride, + groups=branch_channel, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay), + ConvNormLayer( + ch_in=branch_channel, + ch_out=branch_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay), ) + + def forward(self, x): + if self.stride > 1: + x1 = self.branch1(x) + x2 = self.branch2(x) + else: + x1, x2 = x.chunk(2, axis=1) + x2 = self.branch2(x2) + out = paddle.concat([x1, x2], axis=1) + out = 
channel_shuffle(out, groups=2) + return out + + +class IterativeHead(nn.Layer): + def __init__(self, + in_channels, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(IterativeHead, self).__init__() + num_branches = len(in_channels) + self.in_channels = in_channels[::-1] + + projects = [] + for i in range(num_branches): + if i != num_branches - 1: + projects.append( + DepthWiseSeparableConvNormLayer( + ch_in=self.in_channels[i], + ch_out=self.in_channels[i + 1], + filter_size=3, + stride=1, + dw_act=None, + pw_act='relu', + dw_norm_type=norm_type, + pw_norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay)) + else: + projects.append( + DepthWiseSeparableConvNormLayer( + ch_in=self.in_channels[i], + ch_out=self.in_channels[i], + filter_size=3, + stride=1, + dw_act=None, + pw_act='relu', + dw_norm_type=norm_type, + pw_norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay)) + self.projects = nn.LayerList(projects) + + def forward(self, x): + x = x[::-1] + y = [] + last_x = None + for i, s in enumerate(x): + if last_x is not None: + last_x = F.interpolate( + last_x, + size=paddle.shape(s)[-2:], + mode='bilinear', + align_corners=True) + s = s + last_x + s = self.projects[i](s) + y.append(s) + last_x = s + + return y[::-1] + + +class Stem(nn.Layer): + def __init__(self, + in_channel, + stem_channel, + out_channel, + expand_ratio, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(Stem, self).__init__() + self.conv1 = ConvNormLayer( + in_channel, + stem_channel, + filter_size=3, + stride=2, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + mid_channel = int(round(stem_channel * expand_ratio)) + branch_channel = stem_channel // 2 + if stem_channel == out_channel: + inc_channel = out_channel - branch_channel + else: + inc_channel = out_channel - stem_channel + self.branch1 = nn.Sequential( + ConvNormLayer( + ch_in=branch_channel, + ch_out=branch_channel, + filter_size=3, + stride=2, + groups=branch_channel, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay), + ConvNormLayer( + ch_in=branch_channel, + ch_out=inc_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay), ) + self.expand_conv = ConvNormLayer( + ch_in=branch_channel, + ch_out=mid_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + self.depthwise_conv = ConvNormLayer( + ch_in=mid_channel, + ch_out=mid_channel, + filter_size=3, + stride=2, + groups=mid_channel, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay) + self.linear_conv = ConvNormLayer( + ch_in=mid_channel, + ch_out=branch_channel + if stem_channel == out_channel else stem_channel, + filter_size=1, + stride=1, + norm_type=norm_type, + act='relu', + freeze_norm=freeze_norm, + norm_decay=norm_decay) + + def forward(self, x): + x = self.conv1(x) + x1, x2 = x.chunk(2, axis=1) + x1 = self.branch1(x1) + x2 = self.expand_conv(x2) + x2 = self.depthwise_conv(x2) + x2 = self.linear_conv(x2) + out = paddle.concat([x1, x2], axis=1) + out = channel_shuffle(out, groups=2) + + return out + + +class LiteHRNetModule(nn.Layer): + def __init__(self, + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=False, + with_fuse=True, + norm_type='bn', + freeze_norm=False, + norm_decay=0.): + super(LiteHRNetModule, self).__init__() + assert num_branches == 
len(in_channels),\ + "num_branches {} should equal to num_in_channels {}".format(num_branches, len(in_channels)) + assert module_type in [ + 'LITE', 'NAIVE' + ], "module_type should be one of ['LITE', 'NAIVE']" + self.num_branches = num_branches + self.in_channels = in_channels + self.multiscale_output = multiscale_output + self.with_fuse = with_fuse + self.norm_type = 'bn' + self.module_type = module_type + + if self.module_type == 'LITE': + self.layers = self._make_weighting_blocks( + num_blocks, + reduce_ratio, + freeze_norm=freeze_norm, + norm_decay=norm_decay) + elif self.module_type == 'NAIVE': + self.layers = self._make_naive_branches( + num_branches, + num_blocks, + freeze_norm=freeze_norm, + norm_decay=norm_decay) + + if self.with_fuse: + self.fuse_layers = self._make_fuse_layers( + freeze_norm=freeze_norm, norm_decay=norm_decay) + self.relu = nn.ReLU() + + def _make_weighting_blocks(self, + num_blocks, + reduce_ratio, + stride=1, + freeze_norm=False, + norm_decay=0.): + layers = [] + for i in range(num_blocks): + layers.append( + ConditionalChannelWeightingBlock( + self.in_channels, + stride=stride, + reduce_ratio=reduce_ratio, + norm_type=self.norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay)) + return nn.Sequential(*layers) + + def _make_naive_branches(self, + num_branches, + num_blocks, + freeze_norm=False, + norm_decay=0.): + branches = [] + for branch_idx in range(num_branches): + layers = [] + for i in range(num_blocks): + layers.append( + ShuffleUnit( + self.in_channels[branch_idx], + self.in_channels[branch_idx], + stride=1, + norm_type=self.norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay)) + branches.append(nn.Sequential(*layers)) + return nn.LayerList(branches) + + def _make_fuse_layers(self, freeze_norm=False, norm_decay=0.): + if self.num_branches == 1: + return None + fuse_layers = [] + num_out_branches = self.num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(self.num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + Conv2d( + self.in_channels[j], + self.in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False, ), + nn.BatchNorm2D(self.in_channels[i]), + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + Conv2d( + self.in_channels[j], + self.in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=self.in_channels[j], + bias=False, ), + nn.BatchNorm2D(self.in_channels[j]), + Conv2d( + self.in_channels[j], + self.in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False, ), + nn.BatchNorm2D(self.in_channels[i]))) + else: + conv_downsamples.append( + nn.Sequential( + Conv2d( + self.in_channels[j], + self.in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=self.in_channels[j], + bias=False, ), + nn.BatchNorm2D(self.in_channels[j]), + Conv2d( + self.in_channels[j], + self.in_channels[j], + kernel_size=1, + stride=1, + padding=0, + bias=False, ), + nn.BatchNorm2D(self.in_channels[j]), + nn.ReLU())) + + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.LayerList(fuse_layer)) + + return nn.LayerList(fuse_layers) + + def forward(self, x): + if self.num_branches == 1: + return [self.layers[0](x[0])] + if self.module_type == 'LITE': + out = self.layers(x) + elif self.module_type == 'NAIVE': + for i in range(self.num_branches): 
+ x[i] = self.layers[i](x[i]) + out = x + if self.with_fuse: + out_fuse = [] + for i in range(len(self.fuse_layers)): + y = out[0] if i == 0 else self.fuse_layers[i][0](out[0]) + for j in range(self.num_branches): + if j == 0: + y += y + elif i == j: + y += out[j] + else: + y += self.fuse_layers[i][j](out[j]) + if i == 0: + out[i] = y + out_fuse.append(self.relu(y)) + out = out_fuse + elif not self.multiscale_output: + out = [out[0]] + return out + + +class LiteHRNet(nn.Layer): + """ + @inproceedings{Yulitehrnet21, + title={Lite-HRNet: A Lightweight High-Resolution Network}, + author={Yu, Changqian and Xiao, Bin and Gao, Changxin and Yuan, Lu and Zhang, Lei and Sang, Nong and Wang, Jingdong}, + booktitle={CVPR},year={2021} + } + + Args: + network_type (str): the network_type should be one of ["lite_18", "lite_30", "naive", "wider_naive"], + "naive": Simply combining the shuffle block in ShuffleNet and the highresolution design pattern in HRNet. + "wider_naive": Naive network with wider channels in each block. + "lite_18": Lite-HRNet-18, which replaces the pointwise convolution in a shuffle block by conditional channel weighting. + "lite_30": Lite-HRNet-30, with more blocks compared with Lite-HRNet-18. + freeze_at (int): the stage to freeze + freeze_norm (bool): whether to freeze norm in HRNet + norm_decay (float): weight decay for normalization layer weights + return_idx (List): the stage to return + """ + + def __init__(self, + network_type, + freeze_at=0, + freeze_norm=True, + norm_decay=0., + return_idx=[0, 1, 2, 3], + use_head=False, + pretrained=None): + super(LiteHRNet, self).__init__() + if isinstance(return_idx, Integral): + return_idx = [return_idx] + assert network_type in ["lite_18", "lite_30", "naive", "wider_naive"], \ + "the network_type should be one of [lite_18, lite_30, naive, wider_naive]" + assert len(return_idx) > 0, "need one or more return index" + self.freeze_at = freeze_at + self.freeze_norm = freeze_norm + self.norm_decay = norm_decay + self.return_idx = return_idx + self.norm_type = 'bn' + self.use_head = use_head + self.pretrained = pretrained + + self.module_configs = { + "lite_18": { + "num_modules": [2, 4, 2], + "num_branches": [2, 3, 4], + "num_blocks": [2, 2, 2], + "module_type": ["LITE", "LITE", "LITE"], + "reduce_ratios": [8, 8, 8], + "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + }, + "lite_30": { + "num_modules": [3, 8, 3], + "num_branches": [2, 3, 4], + "num_blocks": [2, 2, 2], + "module_type": ["LITE", "LITE", "LITE"], + "reduce_ratios": [8, 8, 8], + "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + }, + "naive": { + "num_modules": [2, 4, 2], + "num_branches": [2, 3, 4], + "num_blocks": [2, 2, 2], + "module_type": ["NAIVE", "NAIVE", "NAIVE"], + "reduce_ratios": [1, 1, 1], + "num_channels": [[30, 60], [30, 60, 120], [30, 60, 120, 240]], + }, + "wider_naive": { + "num_modules": [2, 4, 2], + "num_branches": [2, 3, 4], + "num_blocks": [2, 2, 2], + "module_type": ["NAIVE", "NAIVE", "NAIVE"], + "reduce_ratios": [1, 1, 1], + "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + }, + } + + self.stages_config = self.module_configs[network_type] + + self.stem = Stem(3, 32, 32, 1) + num_channels_pre_layer = [32] + for stage_idx in range(3): + num_channels = self.stages_config["num_channels"][stage_idx] + setattr(self, 'transition{}'.format(stage_idx), + self._make_transition_layer(num_channels_pre_layer, + num_channels, self.freeze_norm, + self.norm_decay)) + stage, num_channels_pre_layer = self._make_stage( + 
self.stages_config, stage_idx, num_channels, True, + self.freeze_norm, self.norm_decay) + setattr(self, 'stage{}'.format(stage_idx), stage) + + num_channels = self.stages_config["num_channels"][-1] + self.feat_channels = num_channels + + if self.use_head: + self.head_layer = IterativeHead(num_channels_pre_layer, 'bn', + self.freeze_norm, self.norm_decay) + + self.feat_channels = [num_channels[0]] + for i in range(1, len(num_channels)): + self.feat_channels.append(num_channels[i] // 2) + + self.init_weight() + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + def _make_transition_layer(self, + num_channels_pre_layer, + num_channels_cur_layer, + freeze_norm=False, + norm_decay=0.): + num_branches_pre = len(num_channels_pre_layer) + num_branches_cur = len(num_channels_cur_layer) + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + Conv2d( + num_channels_pre_layer[i], + num_channels_pre_layer[i], + kernel_size=3, + stride=1, + padding=1, + groups=num_channels_pre_layer[i], + bias=False), + nn.BatchNorm2D(num_channels_pre_layer[i]), + Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=1, + stride=1, + padding=0, + bias=False, ), + nn.BatchNorm2D(num_channels_cur_layer[i]), + nn.ReLU())) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + conv_downsamples.append( + nn.Sequential( + Conv2d( + num_channels_pre_layer[-1], + num_channels_pre_layer[-1], + groups=num_channels_pre_layer[-1], + kernel_size=3, + stride=2, + padding=1, + bias=False, ), + nn.BatchNorm2D(num_channels_pre_layer[-1]), + Conv2d( + num_channels_pre_layer[-1], + num_channels_cur_layer[i] + if j == i - num_branches_pre else + num_channels_pre_layer[-1], + kernel_size=1, + stride=1, + padding=0, + bias=False, ), + nn.BatchNorm2D(num_channels_cur_layer[i] + if j == i - num_branches_pre else + num_channels_pre_layer[-1]), + nn.ReLU())) + transition_layers.append(nn.Sequential(*conv_downsamples)) + return nn.LayerList(transition_layers) + + def _make_stage(self, + stages_config, + stage_idx, + in_channels, + multiscale_output, + freeze_norm=False, + norm_decay=0.): + num_modules = stages_config["num_modules"][stage_idx] + num_branches = stages_config["num_branches"][stage_idx] + num_blocks = stages_config["num_blocks"][stage_idx] + reduce_ratio = stages_config['reduce_ratios'][stage_idx] + module_type = stages_config['module_type'][stage_idx] + + modules = [] + for i in range(num_modules): + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + modules.append( + LiteHRNetModule( + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=reset_multiscale_output, + with_fuse=True, + freeze_norm=freeze_norm, + norm_decay=norm_decay)) + in_channels = modules[-1].in_channels + return nn.Sequential(*modules), in_channels + + def forward(self, x): + x = self.stem(x) + + y_list = [x] + for stage_idx in range(3): + x_list = [] + transition = getattr(self, 'transition{}'.format(stage_idx)) + for j in range(self.stages_config["num_branches"][stage_idx]): + if transition[j] is not None: + if j >= len(y_list): + x_list.append(transition[j](y_list[-1])) + else: + x_list.append(transition[j](y_list[j])) + else: + x_list.append(y_list[j]) + 
y_list = getattr(self, 'stage{}'.format(stage_idx))(x_list) + + if self.use_head: + y_list = self.head_layer(y_list) + + res = [] + for i, layer in enumerate(y_list): + if i == self.freeze_at: + layer.stop_gradient = True + if i in self.return_idx: + res.append(layer) + return res + + +@manager.BACKBONES.add_component +def Lite_HRNet_18(**kwargs): + model = LiteHRNet(network_type="lite_18", **kwargs) + return model + + +@manager.BACKBONES.add_component +def Lite_HRNet_30(**kwargs): + model = LiteHRNet(network_type="lite_30", **kwargs) + return model + + +@manager.BACKBONES.add_component +def Lite_HRNet_naive(**kwargs): + model = LiteHRNet(network_type="naive", **kwargs) + return model + + +@manager.BACKBONES.add_component +def Lite_HRNet_wider_naive(**kwargs): + model = LiteHRNet(network_type="wider_naive", **kwargs) + return model diff --git a/paddleseg/models/backbones/mobilenetv2.py b/paddleseg/models/backbones/mobilenetv2.py index 2efad66d3c..0405cee5a0 100644 --- a/paddleseg/models/backbones/mobilenetv2.py +++ b/paddleseg/models/backbones/mobilenetv2.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle +from paddle import ParamAttr import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddleseg.cvlibs import manager from paddleseg import utils +__all__ = [ + "MobileNetV2_x0_25", + "MobileNetV2_x0_5", + "MobileNetV2_x0_75", + "MobileNetV2_x1_0", + "MobileNetV2_x1_5", + "MobileNetV2_x2_0", +] + -@manager.BACKBONES.add_component class MobileNetV2(nn.Layer): """ The MobileNetV2 implementation based on PaddlePaddle. @@ -29,69 +42,69 @@ class MobileNetV2(nn.Layer): (https://arxiv.org/abs/1801.04381). Args: - channel_ratio (float, optional): The ratio of channel. Default: 1.0 - min_channel (int, optional): The minimum of channel. Default: 16 + scale (float, optional): The scale of channel. Default: 1.0 pretrained (str, optional): The path or url of pretrained model. 
Default: None """ - def __init__(self, channel_ratio=1.0, min_channel=16, pretrained=None): - super(MobileNetV2, self).__init__() - self.channel_ratio = channel_ratio - self.min_channel = min_channel + def __init__(self, scale=1.0, pretrained=None): + super().__init__() + self.scale = scale self.pretrained = pretrained + prefix_name = "" - self.stage0 = conv_bn(3, self.depth(32), 3, 2) - - self.stage1 = InvertedResidual(self.depth(32), self.depth(16), 1, 1) - - self.stage2 = nn.Sequential( - InvertedResidual(self.depth(16), self.depth(24), 2, 6), - InvertedResidual(self.depth(24), self.depth(24), 1, 6), ) - - self.stage3 = nn.Sequential( - InvertedResidual(self.depth(24), self.depth(32), 2, 6), - InvertedResidual(self.depth(32), self.depth(32), 1, 6), - InvertedResidual(self.depth(32), self.depth(32), 1, 6), ) + bottleneck_params_list = [ + (1, 16, 1, 1), + (6, 24, 2, 2), # x4 + (6, 32, 3, 2), # x8 + (6, 64, 4, 2), + (6, 96, 3, 1), # x16 + (6, 160, 3, 2), + (6, 320, 1, 1), # x32 + ] + self.out_index = [1, 2, 4, 6] - self.stage4 = nn.Sequential( - InvertedResidual(self.depth(32), self.depth(64), 2, 6), - InvertedResidual(self.depth(64), self.depth(64), 1, 6), - InvertedResidual(self.depth(64), self.depth(64), 1, 6), - InvertedResidual(self.depth(64), self.depth(64), 1, 6), ) + self.conv1 = ConvBNLayer( + num_channels=3, + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1, + name=prefix_name + "conv1_1") - self.stage5 = nn.Sequential( - InvertedResidual(self.depth(64), self.depth(96), 1, 6), - InvertedResidual(self.depth(96), self.depth(96), 1, 6), - InvertedResidual(self.depth(96), self.depth(96), 1, 6), ) + self.block_list = [] + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s = layer_setting + i += 1 + block = self.add_sublayer( + prefix_name + "conv" + str(i), + sublayer=InvresiBlocks( + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s, + name=prefix_name + "conv" + str(i))) + self.block_list.append(block) + in_c = int(c * scale) - self.stage6 = nn.Sequential( - InvertedResidual(self.depth(96), self.depth(160), 2, 6), - InvertedResidual(self.depth(160), self.depth(160), 1, 6), - InvertedResidual(self.depth(160), self.depth(160), 1, 6), ) - - self.stage7 = InvertedResidual(self.depth(160), self.depth(320), 1, 6) + out_channels = [ + bottleneck_params_list[idx][1] for idx in self.out_index + ] + self.feat_channels = [int(c * scale) for c in out_channels] self.init_weight() - def depth(self, channels): - min_channel = min(channels, self.min_channel) - return max(min_channel, int(channels * self.channel_ratio)) - - def forward(self, x): + def forward(self, inputs): feat_list = [] - feature_1_2 = self.stage0(x) - feature_1_2 = self.stage1(feature_1_2) - feature_1_4 = self.stage2(feature_1_2) - feature_1_8 = self.stage3(feature_1_4) - feature_1_16 = self.stage4(feature_1_8) - feature_1_16 = self.stage5(feature_1_16) - feature_1_32 = self.stage6(feature_1_16) - feature_1_32 = self.stage7(feature_1_32) - feat_list.append(feature_1_4) - feat_list.append(feature_1_8) - feat_list.append(feature_1_16) - feat_list.append(feature_1_32) + y = self.conv1(inputs, if_act=True) + for idx, block in enumerate(self.block_list): + y = block(y) + if idx in self.out_index: + feat_list.append(y) + return feat_list def init_weight(self): @@ -99,66 +112,153 @@ def init_weight(self): utils.load_entire_model(self, self.pretrained) -def conv_bn(inp, oup, kernel, stride): - return nn.Sequential( - nn.Conv2D( - in_channels=inp, - out_channels=oup, - 
kernel_size=kernel, +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + name=None, + use_cudnn=True): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + + self._batch_norm = BatchNorm( + num_filters, + param_attr=ParamAttr(name=name + "_bn_scale"), + bias_attr=ParamAttr(name=name + "_bn_offset"), + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, inputs, if_act=True): + y = self._conv(inputs) + y = self._batch_norm(y) + if if_act: + y = F.relu6(y) + return y + + +class InvertedResidualUnit(nn.Layer): + def __init__(self, num_channels, num_in_filter, num_filters, stride, + filter_size, padding, expansion_factor, name): + super(InvertedResidualUnit, self).__init__() + num_expfilter = int(round(num_in_filter * expansion_factor)) + self._expand_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + name=name + "_expand") + + self._bottleneck_conv = ConvBNLayer( + num_channels=num_expfilter, + num_filters=num_expfilter, + filter_size=filter_size, stride=stride, - padding=(kernel - 1) // 2, - bias_attr=False), - nn.BatchNorm2D( - num_features=oup, epsilon=1e-05, momentum=0.1), - nn.ReLU()) - - -class InvertedResidual(nn.Layer): - def __init__(self, inp, oup, stride, expand_ratio, dilation=1): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - self.use_res_connect = self.stride == 1 and inp == oup - - self.conv = nn.Sequential( - nn.Conv2D( - inp, - inp * expand_ratio, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=1, - bias_attr=False), - nn.BatchNorm2D( - num_features=inp * expand_ratio, epsilon=1e-05, momentum=0.1), - nn.ReLU(), - nn.Conv2D( - inp * expand_ratio, - inp * expand_ratio, - kernel_size=3, - stride=stride, - padding=dilation, - dilation=dilation, - groups=inp * expand_ratio, - bias_attr=False), - nn.BatchNorm2D( - num_features=inp * expand_ratio, epsilon=1e-05, momentum=0.1), - nn.ReLU(), - nn.Conv2D( - inp * expand_ratio, - oup, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=1, - bias_attr=False), - nn.BatchNorm2D( - num_features=oup, epsilon=1e-05, momentum=0.1), ) - - def forward(self, x): - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) + padding=padding, + num_groups=num_expfilter, + use_cudnn=False, + name=name + "_dwise") + + self._linear_conv = ConvBNLayer( + num_channels=num_expfilter, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + name=name + "_linear") + + def forward(self, inputs, ifshortcut): + y = self._expand_conv(inputs, if_act=True) + y = self._bottleneck_conv(y, if_act=True) + y = self._linear_conv(y, if_act=False) + if ifshortcut: + y = paddle.add(inputs, y) + return y + + +class InvresiBlocks(nn.Layer): + def __init__(self, in_c, t, c, n, s, name): + super(InvresiBlocks, self).__init__() + + self._first_block = InvertedResidualUnit( + num_channels=in_c, + num_in_filter=in_c, + num_filters=c, + stride=s, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + "_1") + + self._block_list = [] + for i in range(1, n): + block = self.add_sublayer( + name + 
"_" + str(i + 1), + sublayer=InvertedResidualUnit( + num_channels=c, + num_in_filter=c, + num_filters=c, + stride=1, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + "_" + str(i + 1))) + self._block_list.append(block) + + def forward(self, inputs): + y = self._first_block(inputs, ifshortcut=False) + for block in self._block_list: + y = block(y, ifshortcut=True) + return y + + +@manager.BACKBONES.add_component +def MobileNetV2_x0_25(**kwargs): + model = MobileNetV2(scale=0.25, **kwargs) + return model + + +@manager.BACKBONES.add_component +def MobileNetV2_x0_5(**kwargs): + model = MobileNetV2(scale=0.5, **kwargs) + return model + + +@manager.BACKBONES.add_component +def MobileNetV2_x0_75(**kwargs): + model = MobileNetV2(scale=0.75, **kwargs) + return model + + +@manager.BACKBONES.add_component +def MobileNetV2_x1_0(**kwargs): + model = MobileNetV2(scale=1.0, **kwargs) + return model + + +@manager.BACKBONES.add_component +def MobileNetV2_x1_5(**kwargs): + model = MobileNetV2(scale=1.5, **kwargs) + return model + + +@manager.BACKBONES.add_component +def MobileNetV2_x2_0(**kwargs): + model = MobileNetV2(scale=2.0, **kwargs) + return model diff --git a/paddleseg/models/backbones/mobilenetv3.py b/paddleseg/models/backbones/mobilenetv3.py index 1c19d25fb5..37fc7cc867 100644 --- a/paddleseg/models/backbones/mobilenetv3.py +++ b/paddleseg/models/backbones/mobilenetv3.py @@ -14,10 +14,12 @@ import paddle import paddle.nn as nn -import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay +from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear from paddleseg.cvlibs import manager -from paddleseg.utils import utils +from paddleseg.utils import utils, logger from paddleseg.models import layers __all__ = [ @@ -28,8 +30,59 @@ "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25" ] - -def make_divisible(v, divisor=8, min_value=None): +MODEL_STAGES_PATTERN = { + "MobileNetV3_small": ["blocks[0]", "blocks[2]", "blocks[7]", "blocks[10]"], + "MobileNetV3_large": + ["blocks[0]", "blocks[2]", "blocks[5]", "blocks[11]", "blocks[14]"] +} + +# "large", "small" is just for MobinetV3_large, MobileNetV3_small respectively. +# The type of "large" or "small" config is a list. Each element(list) represents a depthwise block, which is composed of k, exp, se, act, s. 
+# k: kernel_size +# exp: middle channel number in depthwise block +# c: output channel number in depthwise block +# se: whether to use SE block +# act: which activation to use +# s: stride in depthwise block +NET_CONFIG = { + "large": [ + # k, exp, c, se, act, s + [3, 16, 16, False, "relu", 1], + [3, 64, 24, False, "relu", 2], + [3, 72, 24, False, "relu", 1], # x4 + [5, 72, 40, True, "relu", 2], + [5, 120, 40, True, "relu", 1], + [5, 120, 40, True, "relu", 1], # x8 + [3, 240, 80, False, "hardswish", 2], + [3, 200, 80, False, "hardswish", 1], + [3, 184, 80, False, "hardswish", 1], + [3, 184, 80, False, "hardswish", 1], + [3, 480, 112, True, "hardswish", 1], + [3, 672, 112, True, "hardswish", 1], # x16 + [5, 672, 160, True, "hardswish", 2], + [5, 960, 160, True, "hardswish", 1], + [5, 960, 160, True, "hardswish", 1], # x32 + ], + "small": [ + # k, exp, c, se, act, s + [3, 16, 16, True, "relu", 2], + [3, 72, 24, False, "relu", 2], + [3, 88, 24, False, "relu", 1], + [5, 96, 40, True, "hardswish", 2], + [5, 240, 40, True, "hardswish", 1], + [5, 240, 40, True, "hardswish", 1], + [5, 120, 48, True, "hardswish", 1], + [5, 144, 48, True, "hardswish", 1], + [5, 288, 96, True, "hardswish", 2], + [5, 576, 96, True, "hardswish", 1], + [5, 576, 96, True, "hardswish", 1], + ] +} + +OUT_INDEX = {"large": [2, 5, 11, 14], "small": [0, 2, 7, 10]} + + +def _make_divisible(v, divisor=8, min_value=None): if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) @@ -38,156 +91,109 @@ def make_divisible(v, divisor=8, min_value=None): return new_v -class MobileNetV3(nn.Layer): - """ - The MobileNetV3 implementation based on PaddlePaddle. +def _create_act(act): + if act == "hardswish": + return nn.Hardswish() + elif act == "relu": + return nn.ReLU() + elif act is None: + return None + else: + raise RuntimeError( + "The activation function is not supported: {}".format(act)) - The original article refers to Jingdong - Andrew Howard, et, al. "Searching for MobileNetV3" - (https://arxiv.org/pdf/1905.02244.pdf). +class MobileNetV3(nn.Layer): + """ + MobileNetV3 Args: - pretrained (str, optional): The path of pretrained model. - scale (float, optional): The scale of channels . Default: 1.0. - model_name (str, optional): Model name. It determines the type of MobileNetV3. The value is 'small' or 'large'. Defualt: 'small'. - output_stride (int, optional): The stride of output features compared to input images. The value should be one of (2, 4, 8, 16, 32). Default: None. - + config: list. MobileNetV3 depthwise blocks config. + scale: float=1.0. The coefficient that controls the size of network parameters. + Returns: + model: nn.Layer. Specific MobileNetV3 model depends on args. 
""" def __init__(self, - pretrained=None, + config, + stages_pattern, + out_index, scale=1.0, - model_name="small", - output_stride=None): - super(MobileNetV3, self).__init__() + pretrained=None): + super().__init__() + self.cfg = config + self.out_index = out_index + self.scale = scale + self.pretrained = pretrained inplanes = 16 - if model_name == "large": - self.cfg = [ - # k, exp, c, se, nl, s, - [3, 16, 16, False, "relu", 1], - [3, 64, 24, False, "relu", 2], - [3, 72, 24, False, "relu", 1], # output 1 -> out_index=2 - [5, 72, 40, True, "relu", 2], - [5, 120, 40, True, "relu", 1], - [5, 120, 40, True, "relu", 1], # output 2 -> out_index=5 - [3, 240, 80, False, "hard_swish", 2], - [3, 200, 80, False, "hard_swish", 1], - [3, 184, 80, False, "hard_swish", 1], - [3, 184, 80, False, "hard_swish", 1], - [3, 480, 112, True, "hard_swish", 1], - [3, 672, 112, True, "hard_swish", - 1], # output 3 -> out_index=11 - [5, 672, 160, True, "hard_swish", 2], - [5, 960, 160, True, "hard_swish", 1], - [5, 960, 160, True, "hard_swish", - 1], # output 3 -> out_index=14 - ] - self.out_indices = [2, 5, 11, 14] - self.feat_channels = [ - make_divisible(i * scale) for i in [24, 40, 112, 160] - ] - - self.cls_ch_squeeze = 960 - self.cls_ch_expand = 1280 - elif model_name == "small": - self.cfg = [ - # k, exp, c, se, nl, s, - [3, 16, 16, True, "relu", 2], # output 1 -> out_index=0 - [3, 72, 24, False, "relu", 2], - [3, 88, 24, False, "relu", 1], # output 2 -> out_index=3 - [5, 96, 40, True, "hard_swish", 2], - [5, 240, 40, True, "hard_swish", 1], - [5, 240, 40, True, "hard_swish", 1], - [5, 120, 48, True, "hard_swish", 1], - [5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7 - [5, 288, 96, True, "hard_swish", 2], - [5, 576, 96, True, "hard_swish", 1], - [5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10 - ] - self.out_indices = [0, 3, 7, 10] - self.feat_channels = [ - make_divisible(i * scale) for i in [16, 24, 48, 96] - ] - - self.cls_ch_squeeze = 576 - self.cls_ch_expand = 1280 - else: - raise NotImplementedError( - "mode[{}_model] is not implemented!".format(model_name)) - - ################################################### - # modify stride and dilation based on output_stride - self.dilation_cfg = [1] * len(self.cfg) - self.modify_bottle_params(output_stride=output_stride) - ################################################### - - self.conv1 = ConvBNLayer( + + self.conv = ConvBNLayer( in_c=3, - out_c=make_divisible(inplanes * scale), + out_c=_make_divisible(inplanes * self.scale), filter_size=3, stride=2, padding=1, num_groups=1, if_act=True, - act="hard_swish") - - self.block_list = [] - - inplanes = make_divisible(inplanes * scale) - for i, (k, exp, c, se, nl, s) in enumerate(self.cfg): - ###################################### - # add dilation rate - dilation_rate = self.dilation_cfg[i] - ###################################### - self.block_list.append( - ResidualUnit( - in_c=inplanes, - mid_c=make_divisible(scale * exp), - out_c=make_divisible(scale * c), - filter_size=k, - stride=s, - dilation=dilation_rate, - use_se=se, - act=nl, - name="conv" + str(i + 2))) - self.add_sublayer( - sublayer=self.block_list[-1], name="conv" + str(i + 2)) - inplanes = make_divisible(scale * c) - - self.pretrained = pretrained + act="hardswish") + self.blocks = nn.Sequential(*[ + ResidualUnit( + in_c=_make_divisible(inplanes * self.scale if i == 0 else + self.cfg[i - 1][2] * self.scale), + mid_c=_make_divisible(self.scale * exp), + out_c=_make_divisible(self.scale * c), + filter_size=k, + stride=s, 
+ use_se=se, + act=act) for i, (k, exp, c, se, act, s) in enumerate(self.cfg) + ]) + + out_channels = [config[idx][2] for idx in self.out_index] + self.feat_channels = [ + _make_divisible(self.scale * c) for c in out_channels + ] + + self.init_res(stages_pattern) self.init_weight() - def modify_bottle_params(self, output_stride=None): - - if output_stride is not None and output_stride % 2 != 0: - raise ValueError("output stride must to be even number") - if output_stride is not None: - stride = 2 - rate = 1 - for i, _cfg in enumerate(self.cfg): - stride = stride * _cfg[-1] - if stride > output_stride: - rate = rate * _cfg[-1] - self.cfg[i][-1] = 1 + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + def init_res(self, stages_pattern, return_patterns=None, + return_stages=None): + if return_patterns and return_stages: + msg = f"The 'return_patterns' would be ignored when 'return_stages' is set." + logger.warning(msg) + return_stages = None + + if return_stages is True: + return_patterns = stages_pattern + # return_stages is int or bool + if type(return_stages) is int: + return_stages = [return_stages] + if isinstance(return_stages, list): + if max(return_stages) > len(stages_pattern) or min( + return_stages) < 0: + msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}." + logger.warning(msg) + return_stages = [ + val for val in return_stages + if val >= 0 and val < len(stages_pattern) + ] + return_patterns = [stages_pattern[i] for i in return_stages] - self.dilation_cfg[i] = rate + def forward(self, x): + x = self.conv(x) - def forward(self, inputs, label=None): - x = self.conv1(inputs) - # A feature list saves each downsampling feature. 
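+ # out_index selects the blocks whose outputs are returned: for both the "large" and "small" configs these are the stride 4, 8, 16 and 32 feature maps, and their scaled channel numbers are recorded in self.feat_channels for the segmentation head.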
feat_list = [] - for i, block in enumerate(self.block_list): + for idx, block in enumerate(self.blocks): x = block(x) - if i in self.out_indices: + if idx in self.out_index: feat_list.append(x) return feat_list - def init_weight(self): - if self.pretrained is not None: - utils.load_pretrained_model(self, self.pretrained) - class ConvBNLayer(nn.Layer): def __init__(self, @@ -196,36 +202,32 @@ def __init__(self, filter_size, stride, padding, - dilation=1, num_groups=1, if_act=True, act=None): - super(ConvBNLayer, self).__init__() - self.if_act = if_act - self.act = act + super().__init__() - self.conv = nn.Conv2D( + self.conv = Conv2D( in_channels=in_c, out_channels=out_c, kernel_size=filter_size, stride=stride, padding=padding, - dilation=dilation, groups=num_groups, bias_attr=False) - self.bn = layers.SyncBatchNorm( - num_features=out_c, - weight_attr=paddle.ParamAttr( - regularizer=paddle.regularizer.L2Decay(0.0)), - bias_attr=paddle.ParamAttr( - regularizer=paddle.regularizer.L2Decay(0.0))) - self._act_op = layers.Activation(act='hardswish') + self.bn = BatchNorm( + num_channels=out_c, + act=None, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.if_act = if_act + self.act = _create_act(act) def forward(self, x): x = self.conv(x) x = self.bn(x) if self.if_act: - x = self._act_op(x) + x = self.act(x) return x @@ -237,10 +239,8 @@ def __init__(self, filter_size, stride, use_se, - dilation=1, - act=None, - name=''): - super(ResidualUnit, self).__init__() + act=None): + super().__init__() self.if_shortcut = stride == 1 and in_c == out_c self.if_se = use_se @@ -252,19 +252,17 @@ def __init__(self, padding=0, if_act=True, act=act) - self.bottleneck_conv = ConvBNLayer( in_c=mid_c, out_c=mid_c, filter_size=filter_size, stride=stride, - padding='same', - dilation=dilation, + padding=int((filter_size - 1) // 2), num_groups=mid_c, if_act=True, act=act) if self.if_se: - self.mid_se = SEModule(mid_c, name=name + "_se") + self.mid_se = SEModule(mid_c) self.linear_conv = ConvBNLayer( in_c=mid_c, out_c=out_c, @@ -273,92 +271,165 @@ def __init__(self, padding=0, if_act=False, act=None) - self.dilation = dilation - def forward(self, inputs): - x = self.expand_conv(inputs) + def forward(self, x): + identity = x + x = self.expand_conv(x) x = self.bottleneck_conv(x) if self.if_se: x = self.mid_se(x) x = self.linear_conv(x) if self.if_shortcut: - x = inputs + x + x = paddle.add(identity, x) return x +# nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid +class Hardsigmoid(nn.Layer): + def __init__(self, slope=0.2, offset=0.5): + super().__init__() + self.slope = slope + self.offset = offset + + def forward(self, x): + return nn.functional.hardsigmoid( + x, slope=self.slope, offset=self.offset) + + class SEModule(nn.Layer): - def __init__(self, channel, reduction=4, name=""): - super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2D(1) - self.conv1 = nn.Conv2D( + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0) - self.conv2 = nn.Conv2D( + self.relu = nn.ReLU() + self.conv2 = Conv2D( in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0) + self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5) - def forward(self, inputs): - outputs = self.avg_pool(inputs) - outputs = self.conv1(outputs) - outputs = 
F.relu(outputs) - outputs = self.conv2(outputs) - outputs = F.hardsigmoid(outputs) - return paddle.multiply(x=inputs, y=outputs) + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + return paddle.multiply(x=identity, y=x) +@manager.BACKBONES.add_component def MobileNetV3_small_x0_35(**kwargs): - model = MobileNetV3(model_name="small", scale=0.35, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.35, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"], + out_index=OUT_INDEX["small"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_small_x0_5(**kwargs): - model = MobileNetV3(model_name="small", scale=0.5, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.5, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"], + out_index=OUT_INDEX["small"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_small_x0_75(**kwargs): - model = MobileNetV3(model_name="small", scale=0.75, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=0.75, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"], + out_index=OUT_INDEX["small"], + **kwargs) return model @manager.BACKBONES.add_component def MobileNetV3_small_x1_0(**kwargs): - model = MobileNetV3(model_name="small", scale=1.0, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=1.0, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"], + out_index=OUT_INDEX["small"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_small_x1_25(**kwargs): - model = MobileNetV3(model_name="small", scale=1.25, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["small"], + scale=1.25, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"], + out_index=OUT_INDEX["small"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_large_x0_35(**kwargs): - model = MobileNetV3(model_name="large", scale=0.35, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.35, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"], + out_index=OUT_INDEX["large"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_large_x0_5(**kwargs): - model = MobileNetV3(model_name="large", scale=0.5, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.5, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"], + out_index=OUT_INDEX["large"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_large_x0_75(**kwargs): - model = MobileNetV3(model_name="large", scale=0.75, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=0.75, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"], + out_index=OUT_INDEX["large"], + **kwargs) return model @manager.BACKBONES.add_component def MobileNetV3_large_x1_0(**kwargs): - model = MobileNetV3(model_name="large", scale=1.0, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=1.0, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"], + out_index=OUT_INDEX["large"], + **kwargs) return model +@manager.BACKBONES.add_component def MobileNetV3_large_x1_25(**kwargs): - model = MobileNetV3(model_name="large", scale=1.25, **kwargs) + model = MobileNetV3( + config=NET_CONFIG["large"], + scale=1.25, + stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"], + out_index=OUT_INDEX["large"], + **kwargs) return model diff --git
a/paddleseg/models/backbones/shufflenetv2.py b/paddleseg/models/backbones/shufflenetv2.py new file mode 100644 index 0000000000..a4c4ae6b18 --- /dev/null +++ b/paddleseg/models/backbones/shufflenetv2.py @@ -0,0 +1,315 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import ParamAttr, reshape, transpose, concat, split +from paddle.nn import Layer, Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm, Linear +from paddle.nn.initializer import KaimingNormal +from paddle.nn.functional import swish + +from paddleseg.cvlibs import manager +from paddleseg.utils import utils, logger + +__all__ = [ + 'ShuffleNetV2_x0_25', 'ShuffleNetV2_x0_33', 'ShuffleNetV2_x0_5', + 'ShuffleNetV2_x1_0', 'ShuffleNetV2_x1_5', 'ShuffleNetV2_x2_0', + 'ShuffleNetV2_swish' +] + + +def channel_shuffle(x, groups): + x_shape = paddle.shape(x) + batch_size, height, width = x_shape[0], x_shape[2], x_shape[3] + num_channels = x.shape[1] + channels_per_group = num_channels // groups + + # reshape + x = reshape( + x=x, shape=[batch_size, groups, channels_per_group, height, width]) + + # transpose + x = transpose(x=x, perm=[0, 2, 1, 3, 4]) + + # flatten + x = reshape(x=x, shape=[batch_size, num_channels, height, width]) + + return x + + +class ConvBNLayer(Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + self._conv = Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr( + initializer=KaimingNormal(), name=name + "_weights"), + bias_attr=False) + + self._batch_norm = BatchNorm( + out_channels, + param_attr=ParamAttr(name=name + "_bn_scale"), + bias_attr=ParamAttr(name=name + "_bn_offset"), + act=act, + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class InvertedResidual(Layer): + def __init__(self, in_channels, out_channels, stride, act="relu", + name=None): + super(InvertedResidual, self).__init__() + self._conv_pw = ConvBNLayer( + in_channels=in_channels // 2, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act, + name='stage_' + name + '_conv1') + self._conv_dw = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=3, + stride=stride, + padding=1, + groups=out_channels // 2, + act=None, + name='stage_' + name + '_conv2') + self._conv_linear = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act, + name='stage_' + name + '_conv3') + + def forward(self, inputs): + x1, x2 = split( + inputs, + num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2], + axis=1) + x2 = self._conv_pw(x2) + x2 = 
self._conv_dw(x2) + x2 = self._conv_linear(x2) + out = concat([x1, x2], axis=1) + return channel_shuffle(out, 2) + + +class InvertedResidualDS(Layer): + def __init__(self, in_channels, out_channels, stride, act="relu", + name=None): + super(InvertedResidualDS, self).__init__() + + # branch1 + self._conv_dw_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + act=None, + name='stage_' + name + '_conv4') + self._conv_linear_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act, + name='stage_' + name + '_conv5') + # branch2 + self._conv_pw_2 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act, + name='stage_' + name + '_conv1') + self._conv_dw_2 = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=3, + stride=stride, + padding=1, + groups=out_channels // 2, + act=None, + name='stage_' + name + '_conv2') + self._conv_linear_2 = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act, + name='stage_' + name + '_conv3') + + def forward(self, inputs): + x1 = self._conv_dw_1(inputs) + x1 = self._conv_linear_1(x1) + x2 = self._conv_pw_2(inputs) + x2 = self._conv_dw_2(x2) + x2 = self._conv_linear_2(x2) + out = concat([x1, x2], axis=1) + + return channel_shuffle(out, 2) + + +class ShuffleNet(Layer): + def __init__(self, scale=1.0, act="relu", pretrained=None): + super(ShuffleNet, self).__init__() + self.scale = scale + self.pretrained = pretrained + stage_repeats = [4, 8, 4] + + if scale == 0.25: + stage_out_channels = [-1, 24, 24, 48, 96, 512] + elif scale == 0.33: + stage_out_channels = [-1, 24, 32, 64, 128, 512] + elif scale == 0.5: + stage_out_channels = [-1, 24, 48, 96, 192, 1024] + elif scale == 1.0: + stage_out_channels = [-1, 24, 116, 232, 464, 1024] + elif scale == 1.5: + stage_out_channels = [-1, 24, 176, 352, 704, 1024] + elif scale == 2.0: + stage_out_channels = [-1, 24, 224, 488, 976, 2048] + else: + raise NotImplementedError("This scale size:[" + str(scale) + + "] is not implemented!") + + self.out_index = [3, 11, 15] + self.feat_channels = stage_out_channels[1:5] + + # 1. conv1 + self._conv1 = ConvBNLayer( + in_channels=3, + out_channels=stage_out_channels[1], + kernel_size=3, + stride=2, + padding=1, + act=act, + name='stage1_conv') + self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) + + # 2. 
bottleneck sequences + self._block_list = [] + for stage_id, num_repeat in enumerate(stage_repeats): + for i in range(num_repeat): + if i == 0: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidualDS( + in_channels=stage_out_channels[stage_id + 1], + out_channels=stage_out_channels[stage_id + 2], + stride=2, + act=act, + name=str(stage_id + 2) + '_' + str(i + 1))) + else: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidual( + in_channels=stage_out_channels[stage_id + 2], + out_channels=stage_out_channels[stage_id + 2], + stride=1, + act=act, + name=str(stage_id + 2) + '_' + str(i + 1))) + self._block_list.append(block) + + self.init_weight() + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + def forward(self, inputs): + feat_list = [] + + y = self._conv1(inputs) + y = self._max_pool(y) + feat_list.append(y) + + for idx, inv in enumerate(self._block_list): + y = inv(y) + if idx in self.out_index: + feat_list.append(y) + return feat_list + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x0_25(**kwargs): + model = ShuffleNet(scale=0.25, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x0_33(**kwargs): + model = ShuffleNet(scale=0.33, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x0_5(**kwargs): + model = ShuffleNet(scale=0.5, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x1_0(**kwargs): + model = ShuffleNet(scale=1.0, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x1_5(**kwargs): + model = ShuffleNet(scale=1.5, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_x2_0(**kwargs): + model = ShuffleNet(scale=2.0, **kwargs) + return model + + +@manager.BACKBONES.add_component +def ShuffleNetV2_swish(**kwargs): + model = ShuffleNet(scale=1.0, act="swish", **kwargs) + return model diff --git a/paddleseg/models/layers/__init__.py b/paddleseg/models/layers/__init__.py index 5c32fc06d5..9429834afe 100644 --- a/paddleseg/models/layers/__init__.py +++ b/paddleseg/models/layers/__init__.py @@ -18,4 +18,4 @@ from .attention import AttentionBlock from .nonlocal2d import NonLocal2D from .wrap_functions import * -from .tensor_fusion import UAFM_SpAtten, UAFM_SpAtten_S, UAFM_ChAtten, UAFM_ChAtten_S +from .tensor_fusion import UAFM_SpAtten, UAFM_SpAtten_S, UAFM_ChAtten, UAFM_ChAtten_S, UAFM, UAFMMobile diff --git a/paddleseg/models/layers/tensor_fusion.py b/paddleseg/models/layers/tensor_fusion.py index 1bac3c7b94..a50a96b8ad 100644 --- a/paddleseg/models/layers/tensor_fusion.py +++ b/paddleseg/models/layers/tensor_fusion.py @@ -218,3 +218,23 @@ def fuse(self, x, y): out = x * atten + y * (1 - atten) out = self.conv_out(out) return out + + +class UAFMMobile(UAFM): + """ + Unified Attention Fusion Module for mobile. + Args: + x_ch (int): The channel of x tensor, which is the low level feature. + y_ch (int): The channel of y tensor, which is the high level feature. + out_ch (int): The channel of output tensor. + ksize (int, optional): The kernel size of the conv for x tensor. Default: 3. + resize_mode (str, optional): The resize model in unsampling y tensor. Default: bilinear. 
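Since UAFMMobile is exported from `paddleseg.models.layers` by the `__init__.py` change above, here is a minimal shape-check sketch of how it fuses a low-level and a high-level feature. The tensor sizes are illustrative, and the expected output assumes the base UAFM behaviour of projecting x to y_ch channels, upsampling y to x's resolution, fusing by spatial attention, and mapping the result to out_ch:

```python
import paddle
from paddleseg.models.layers import UAFMMobile

# Illustrative values: x is a stride-8 low-level feature, y a stride-16 high-level one.
fuse = UAFMMobile(x_ch=40, y_ch=128, out_ch=64)
x = paddle.rand([1, 40, 64, 128])   # low level feature
y = paddle.rand([1, 128, 32, 64])   # high level feature

out = fuse(x, y)
print(out.shape)  # expected: [1, 64, 64, 128], i.e. out_ch channels at x's resolution
```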
+ """ + + def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='bilinear'): + super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode) + + self.conv_x = layers.SeparableConvBNReLU( + x_ch, y_ch, kernel_size=ksize, padding=ksize // 2, bias_attr=False) + self.conv_out = layers.SeparableConvBNReLU( + y_ch, out_ch, kernel_size=3, padding=1, bias_attr=False) diff --git a/paddleseg/models/mobileseg.py b/paddleseg/models/mobileseg.py new file mode 100644 index 0000000000..9d60863d83 --- /dev/null +++ b/paddleseg/models/mobileseg.py @@ -0,0 +1,258 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg import utils +from paddleseg.models import layers +from paddleseg.cvlibs import manager + + +@manager.MODELS.add_component +class MobileSeg(nn.Layer): + """ + The semantic segmentation models for mobile devices. + + Args: + num_classes (int): The number of target classes. + backbone(nn.Layer): Backbone network, such as stdc1net and resnet18. The backbone must + has feat_channels, of which the length is 5. + backbone_indices (List(int), optional): The values indicate the indices of output of backbone. + Default: [2, 3, 4]. + cm_bin_sizes (List(int), optional): The bin size of context module. Default: [1,2,4]. + cm_out_ch (int, optional): The output channel of the last context module. Default: 128. + arm_type (str, optional): The type of attention refinement module. Default: ARM_Add_SpAttenAdd3. + arm_out_chs (List(int), optional): The out channels of each arm module. Default: [64, 96, 128]. + seg_head_inter_chs (List(int), optional): The intermediate channels of segmentation head. + Default: [64, 64, 64]. + resize_mode (str, optional): The resize mode for the upsampling operation in decoder. + Default: bilinear. + pretrained (str, optional): The path or url of pretrained model. Default: None. + """ + + def __init__(self, + num_classes, + backbone, + backbone_indices=[1, 2, 3], + cm_bin_sizes=[1, 2], + cm_out_ch=64, + arm_type='UAFMMobile', + arm_out_chs=[32, 48, 64], + seg_head_inter_chs=[32, 32, 32], + resize_mode='bilinear', + pretrained=None): + super().__init__() + + # backbone + assert hasattr(backbone, 'feat_channels'), \ + "The backbone should has feat_channels." + assert len(backbone.feat_channels) >= len(backbone_indices), \ + f"The length of input backbone_indices ({len(backbone_indices)}) should not be" \ + f"greater than the length of feat_channels ({len(backbone.feat_channels)})." + assert len(backbone.feat_channels) > max(backbone_indices), \ + f"The max value ({max(backbone_indices)}) of backbone_indices should be " \ + f"less than the length of feat_channels ({len(backbone.feat_channels)})." 
+ self.backbone = backbone + + assert len(backbone_indices) >= 1, "The length of backbone_indices " \ + "should not be less than 1" + self.backbone_indices = backbone_indices # [..., x16_id, x32_id] + backbone_out_chs = [backbone.feat_channels[i] for i in backbone_indices] + + # head + if len(arm_out_chs) == 1: + arm_out_chs = arm_out_chs * len(backbone_indices) + assert len(arm_out_chs) == len(backbone_indices), "The length of " \ + "arm_out_chs and backbone_indices should be equal" + + self.ppseg_head = MobileSegHead(backbone_out_chs, arm_out_chs, + cm_bin_sizes, cm_out_ch, arm_type, + resize_mode) + + if len(seg_head_inter_chs) == 1: + seg_head_inter_chs = seg_head_inter_chs * len(backbone_indices) + assert len(seg_head_inter_chs) == len(backbone_indices), "The length of " \ + "seg_head_inter_chs and backbone_indices should be equal" + self.seg_heads = nn.LayerList() # [..., head16, head32] + for in_ch, mid_ch in zip(arm_out_chs, seg_head_inter_chs): + self.seg_heads.append(SegHead(in_ch, mid_ch, num_classes)) + + # pretrained + self.pretrained = pretrained + self.init_weight() + + def forward(self, x): + x_hw = paddle.shape(x)[2:] + + feats_backbone = self.backbone(x) # [x4, x8, x16, x32] + assert len(feats_backbone) >= len(self.backbone_indices), \ + f"The number of backbone features ({len(feats_backbone)}) should be greater than or " \ + f"equal to the number of backbone_indices ({len(self.backbone_indices)})" + + feats_selected = [feats_backbone[i] for i in self.backbone_indices] + feats_head = self.ppseg_head(feats_selected) # [..., x8, x16, x32] + + if self.training: + logit_list = [] + for x, seg_head in zip(feats_head, self.seg_heads): + x = seg_head(x) + logit_list.append(x) + logit_list = [ + F.interpolate( + x, x_hw, mode='bilinear', align_corners=False) + for x in logit_list + ] + else: + x = self.seg_heads[0](feats_head[0]) + x = F.interpolate(x, x_hw, mode='bilinear', align_corners=False) + logit_list = [x] + + return logit_list + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + +class MobileSegHead(nn.Layer): + """ + The head of MobileSeg. + + Args: + backbone_out_chs (List(int)): The channels of output tensors in the backbone. + arm_out_chs (List(int)): The out channels of each arm module. + cm_bin_sizes (List(int)): The bin sizes of context module. + cm_out_ch (int): The output channel of the last context module. + arm_type (str): The type of attention refinement module. + resize_mode (str): The resize mode for the upsampling operation in decoder. + """ + + def __init__(self, backbone_out_chs, arm_out_chs, cm_bin_sizes, cm_out_ch, + arm_type, resize_mode): + super().__init__() + + self.cm = MobileContextModule(backbone_out_chs[-1], cm_out_ch, + cm_out_ch, cm_bin_sizes) + + assert hasattr(layers, arm_type), \ + "Unsupported arm_type ({})".format(arm_type) + arm_class = getattr(layers, arm_type) + + self.arm_list = nn.LayerList() # [..., arm8, arm16, arm32] + for i in range(len(backbone_out_chs)): + low_chs = backbone_out_chs[i] + high_ch = cm_out_ch if i == len( + backbone_out_chs) - 1 else arm_out_chs[i + 1] + out_ch = arm_out_chs[i] + arm = arm_class( + low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode) + self.arm_list.append(arm) + + def forward(self, in_feat_list): + """ + Args: + in_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32]. + x2, x4 and x8 are optional. + Returns: + out_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32]. + x2, x4 and x8 are optional.
+ The length of in_feat_list and out_feat_list are the same. + """ + + high_feat = self.cm(in_feat_list[-1]) + out_feat_list = [] + + for i in reversed(range(len(in_feat_list))): + low_feat = in_feat_list[i] + arm = self.arm_list[i] + high_feat = arm(low_feat, high_feat) + out_feat_list.insert(0, high_feat) + + return out_feat_list + + +class MobileContextModule(nn.Layer): + """ + Context Module for Mobile Model. + + Args: + in_channels (int): The number of input channels to pyramid pooling module. + inter_channels (int): The number of inter channels to pyramid pooling module. + out_channels (int): The number of output channels after pyramid pooling module. + bin_sizes (tuple, optional): The out size of pooled feature maps. Default: (1, 3). + align_corners (bool): An argument of F.interpolate. It should be set to False + when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. + """ + + def __init__(self, + in_channels, + inter_channels, + out_channels, + bin_sizes, + align_corners=False): + super().__init__() + + self.stages = nn.LayerList([ + self._make_stage(in_channels, inter_channels, size) + for size in bin_sizes + ]) + + self.conv_out = layers.SeparableConvBNReLU( + in_channels=inter_channels, + out_channels=out_channels, + kernel_size=3, + bias_attr=False) + + self.align_corners = align_corners + + def _make_stage(self, in_channels, out_channels, size): + prior = nn.AdaptiveAvgPool2D(output_size=size) + conv = layers.ConvBNReLU( + in_channels=in_channels, out_channels=out_channels, kernel_size=1) + return nn.Sequential(prior, conv) + + def forward(self, input): + out = None + input_shape = paddle.shape(input)[2:] + + for stage in self.stages: + x = stage(input) + x = F.interpolate( + x, + input_shape, + mode='bilinear', + align_corners=self.align_corners) + if out is None: + out = x + else: + out += x + + out = self.conv_out(out) + return out + + +class SegHead(nn.Layer): + def __init__(self, in_chan, mid_chan, n_classes): + super().__init__() + self.conv = layers.SeparableConvBNReLU( + in_chan, mid_chan, kernel_size=3, bias_attr=False) + self.conv_out = nn.Conv2D( + mid_chan, n_classes, kernel_size=1, bias_attr=False) + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x
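Putting the pieces together, a hedged end-to-end sketch of MobileSeg with the MobileNetV3 backbone from this change. It assumes both classes are exported from `paddleseg.models` and `paddleseg.models.backbones` in the usual way; with the defaults above, three seg heads are supervised during training and only the highest-resolution head runs at inference:

```python
import paddle
from paddleseg.models import MobileSeg
from paddleseg.models.backbones import MobileNetV3_large_x1_0

backbone = MobileNetV3_large_x1_0()                   # feat_channels: [24, 40, 112, 160]
model = MobileSeg(num_classes=19, backbone=backbone)  # 19 classes for Cityscapes

x = paddle.rand([1, 3, 512, 1024])

model.train()
logits = model(x)      # 3 logits (one per seg head), each [1, 19, 512, 1024]

model.eval()
with paddle.no_grad():
    logits = model(x)  # a single logit, [1, 19, 512, 1024]
print(len(logits), logits[0].shape)
```

The three training-time logits line up with the three OhemCrossEntropyLoss terms (coef [1, 1, 1]) in the configs above.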