|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +from mmcv.cnn import ConvModule |
| 4 | +from mmcv.runner.base_module import BaseModule |
| 5 | + |
| 6 | +from mmseg.ops import resize |
| 7 | +from ..builder import HEADS |
| 8 | +from .decode_head import BaseDecodeHead |
| 9 | +from .psp_head import PPM |
| 10 | + |
| 11 | + |
| 12 | +@HEADS.register_module() |
| 13 | +class SFNetHead(BaseDecodeHead): |
| 14 | + """Semantic Flow for Fast and Accurate SceneParsing. |
| 15 | +
|
| 16 | + This head is the implementation of |
| 17 | + `SFSegNet <https://arxiv.org/pdf/2002.10120>`_. |
| 18 | +
|
| 19 | + Args: |
| 20 | + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid |
| 21 | + Module. Default: (1, 2, 3, 6). |
| 22 | + fpn_inplanes (list): |
| 23 | + The list of feature channels number from backbone. |
| 24 | + fpn_dim (int, optional): |
| 25 | + The input channels of FAM module. |
| 26 | + Default: 256 for ResNet50, 128 for ResNet18. |
| 27 | + """ |
| 28 | + |
| 29 | + def __init__(self, |
| 30 | + pool_scales=(1, 2, 3, 6), |
| 31 | + fpn_inplanes=[256, 512, 1024, 2048], |
| 32 | + fpn_dim=256, |
| 33 | + **kwargs): |
| 34 | + super(SFNetHead, self).__init__(**kwargs) |
| 35 | + assert isinstance(pool_scales, (list, tuple)) |
| 36 | + self.pool_scales = pool_scales |
| 37 | + self.fpn_inplanes = fpn_inplanes |
| 38 | + self.fpn_dim = fpn_dim |
| 39 | + self.psp_modules = PPM( |
| 40 | + self.pool_scales, |
| 41 | + self.in_channels, |
| 42 | + self.in_channels // 4, |
| 43 | + bias=True, |
| 44 | + conv_cfg=self.conv_cfg, |
| 45 | + norm_cfg=self.norm_cfg, |
| 46 | + act_cfg=self.act_cfg, |
| 47 | + align_corners=True) |
| 48 | + self.bottleneck = ConvModule( |
| 49 | + self.in_channels * 2, |
| 50 | + self.channels, |
| 51 | + 3, |
| 52 | + padding=1, |
| 53 | + bias=True, |
| 54 | + conv_cfg=self.conv_cfg, |
| 55 | + norm_cfg=self.norm_cfg, |
| 56 | + act_cfg=self.act_cfg) |
| 57 | + |
| 58 | + self.fpn_in = [] |
| 59 | + for fpn_inplane in self.fpn_inplanes[:-1]: |
| 60 | + self.fpn_in.append( |
| 61 | + ConvModule( |
| 62 | + fpn_inplane, |
| 63 | + self.fpn_dim, |
| 64 | + kernel_size=1, |
| 65 | + bias=True, |
| 66 | + conv_cfg=self.conv_cfg, |
| 67 | + norm_cfg=self.norm_cfg, |
| 68 | + act_cfg=self.act_cfg, |
| 69 | + inplace=False)) |
| 70 | + self.fpn_in = nn.ModuleList(self.fpn_in) |
| 71 | + self.fpn_out = [] |
| 72 | + self.fpn_out_align = [] |
| 73 | + self.dsn = [] |
| 74 | + for i in range(len(self.fpn_inplanes) - 1): |
| 75 | + self.fpn_out.append( |
| 76 | + ConvModule( |
| 77 | + self.fpn_dim, |
| 78 | + self.fpn_dim, |
| 79 | + kernel_size=3, |
| 80 | + stride=1, |
| 81 | + padding=1, |
| 82 | + bias=False, |
| 83 | + conv_cfg=self.conv_cfg, |
| 84 | + norm_cfg=self.norm_cfg, |
| 85 | + act_cfg=self.act_cfg, |
| 86 | + inplace=True)) |
| 87 | + self.fpn_out_align.append( |
| 88 | + AlignedModule( |
| 89 | + inplane=self.fpn_dim, outplane=self.fpn_dim // 2)) |
| 90 | + |
| 91 | + self.fpn_out = nn.ModuleList(self.fpn_out) |
| 92 | + self.fpn_out_align = nn.ModuleList(self.fpn_out_align) |
| 93 | + self.conv_last = ConvModule( |
| 94 | + len(self.fpn_inplanes) * self.fpn_dim, |
| 95 | + self.fpn_dim, |
| 96 | + kernel_size=3, |
| 97 | + stride=1, |
| 98 | + padding=1, |
| 99 | + bias=False, |
| 100 | + conv_cfg=self.conv_cfg, |
| 101 | + norm_cfg=self.norm_cfg, |
| 102 | + act_cfg=self.act_cfg, |
| 103 | + inplace=True) |
| 104 | + |
| 105 | + def forward(self, inputs): |
| 106 | + x = self._transform_inputs(inputs) |
| 107 | + psp_outs = [x] |
| 108 | + psp_outs.extend(self.psp_modules(x)[::-1]) |
| 109 | + psp_outs = torch.cat(psp_outs, dim=1) |
| 110 | + psp_out = self.bottleneck(psp_outs) |
| 111 | + |
| 112 | + f = psp_out |
| 113 | + fpn_feature_list = [psp_out] |
| 114 | + |
| 115 | + for i in reversed(range(len(inputs) - 1)): |
| 116 | + conv_x = inputs[i] |
| 117 | + conv_x = self.fpn_in[i](conv_x) |
| 118 | + f = self.fpn_out_align[i]([conv_x, f]) |
| 119 | + f = conv_x + f |
| 120 | + fpn_feature_list.append(self.fpn_out[i](f)) |
| 121 | + |
| 122 | + fpn_feature_list.reverse() # [P2 - P5] |
| 123 | + output_size = fpn_feature_list[0].size()[2:] |
| 124 | + fusion_list = [fpn_feature_list[0]] |
| 125 | + |
| 126 | + for i in range(1, len(fpn_feature_list)): |
| 127 | + fusion_list.append( |
| 128 | + nn.functional.interpolate( |
| 129 | + fpn_feature_list[i], |
| 130 | + output_size, |
| 131 | + mode='bilinear', |
| 132 | + align_corners=True)) |
| 133 | + |
| 134 | + fusion_out = torch.cat(fusion_list, 1) |
| 135 | + x = self.conv_last(fusion_out) |
| 136 | + output = self.cls_seg(x) |
| 137 | + |
| 138 | + return output |
| 139 | + |
| 140 | + |
| 141 | +class AlignedModule(BaseModule): |
| 142 | + """The implementation of Flow Alignment Module (FAM). |
| 143 | +
|
| 144 | + Args: |
| 145 | + inplane (int): The number of FAM input channles. |
| 146 | + outplane (int): The number of FAM output channles. |
| 147 | + """ |
| 148 | + |
| 149 | + def __init__(self, inplane, outplane, kernel_size=3): |
| 150 | + super(AlignedModule, self).__init__() |
| 151 | + self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False) |
| 152 | + self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) |
| 153 | + self.flow_make = nn.Conv2d( |
| 154 | + outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) |
| 155 | + |
| 156 | + def forward(self, x): |
| 157 | + low_feature, h_feature = x |
| 158 | + h_feature_orign = h_feature |
| 159 | + h, w = low_feature.size()[2:] |
| 160 | + size = (h, w) |
| 161 | + low_feature = self.down_l(low_feature) |
| 162 | + h_feature = self.down_h(h_feature) |
| 163 | + h_feature = resize( |
| 164 | + h_feature, size=size, mode='bilinear', align_corners=True) |
| 165 | + flow = self.flow_make(torch.cat([h_feature, low_feature], 1)) |
| 166 | + h_feature = self.flow_warp(h_feature_orign, flow, size=size) |
| 167 | + |
| 168 | + return h_feature |
| 169 | + |
| 170 | + def flow_warp(self, input, flow, size): |
| 171 | + """Implementation of Warp Procedure in Fig 3(b) of original paper, |
| 172 | + which is between Flow Field and High Resolution Feature Map. |
| 173 | +
|
| 174 | + Args: |
| 175 | + input (Tensor): High Resolution Feature Map. |
| 176 | + flow (Tensor): Semantic Flow Field that will give |
| 177 | + dynamic indication about how to align these |
| 178 | + two feature maps effectively. |
| 179 | + size (Tuple): Shape of height and width of output. |
| 180 | +
|
| 181 | + Returns: |
| 182 | + output (Tensor): High Resolution Feature Map after |
| 183 | + warped offset and bilinear interpolation. |
| 184 | +
|
| 185 | + For example, in cityscapes 1024x2048 dataset with ResNet18 config, |
| 186 | + feature map from backbone is: |
| 187 | + [[1, 64, 256, 512], |
| 188 | + [1, 128, 128, 256], |
| 189 | + [1, 256, 64, 128], |
| 190 | + [1, 512, 32, 64]] |
| 191 | +
|
| 192 | + Thus, its inverse shape of [input, flow, size] is: |
| 193 | + [[1, 128, 32, 64], [1, 2, 64, 128], (64, 128)], |
| 194 | + [[1, 128, 64, 128], [1, 2, 128, 256], (128, 256)], and |
| 195 | + [[1, 128, 128, 256], [1, 2, 256, 512], (256, 512)], respectively. |
| 196 | +
|
| 197 | + The final output is: |
| 198 | + [[1, 128, 64, 128], |
| 199 | + [1, 128, 128, 256], |
| 200 | + [1, 128, 256, 512]], respectively. |
| 201 | + """ |
| 202 | + |
| 203 | + out_h, out_w = size |
| 204 | + n, c, h, w = input.size() |
| 205 | + |
| 206 | + # Warped offset in grid, from -1 to 1. |
| 207 | + norm = torch.tensor([[[[out_w, |
| 208 | + out_h]]]]).type_as(input).to(input.device) |
| 209 | + h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) |
| 210 | + w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) |
| 211 | + grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) |
| 212 | + grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) |
| 213 | + |
| 214 | + # Warped grid which is corrected the flow offset. |
| 215 | + grid = grid + flow.permute(0, 2, 3, 1) / norm |
| 216 | + |
| 217 | + # Sampling mechanism interpolates the values of the 4-neighbors |
| 218 | + # (top-left, top-right, bottom-left, and bottom-right) of input. |
| 219 | + output = nn.functional.grid_sample(input, grid, align_corners=True) |
| 220 | + return output |
0 commit comments