Skip to content

Commit

Permalink
[Fix] Fix the output position of Swin-Transformer. (open-mmlab#947)
Browse files Browse the repository at this point in the history
* [Fix] Fix the output position of Swin-Transformer.

* Rename `downsample` argument to `do_downsample`.
  • Loading branch information
mzr1996 authored Aug 3, 2022
1 parent 6ec38fe commit b5bb86a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
25 changes: 19 additions & 6 deletions mmcls/models/backbones/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ def __init__(self,
else:
self.downsample = None

def forward(self, x, in_shape):
def forward(self, x, in_shape, do_downsample=True):
for block in self.blocks:
x = block(x, in_shape)

if self.downsample:
if self.downsample is not None and do_downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
Expand Down Expand Up @@ -232,6 +232,8 @@ class SwinTransformer(BaseBackbone):
window_size (int): The height and width of the window. Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_after_downsample (bool): Whether to output the feature map of a
stage after the following downsample layer. Defaults to False.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
Expand Down Expand Up @@ -301,6 +303,7 @@ def __init__(self,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
out_after_downsample=False,
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
Expand Down Expand Up @@ -329,6 +332,7 @@ def __init__(self,
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.out_after_downsample = out_after_downsample
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
Expand Down Expand Up @@ -392,9 +396,15 @@ def __init__(self,
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)

if self.out_after_downsample:
self.num_features = embed_dims[1:]
else:
self.num_features = embed_dims[:-1]

for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
norm_layer = build_norm_layer(norm_cfg,
self.num_features[i])[1]
else:
norm_layer = nn.Identity()

Expand All @@ -421,14 +431,17 @@ def forward(self, x):

outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
x, hw_shape = stage(
x, hw_shape, do_downsample=self.out_after_downsample)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if stage.downsample is not None and not self.out_after_downsample:
x, hw_shape = stage.downsample(x, hw_shape)

return tuple(outs)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_backbones/test_swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_forward(self):
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 4)
for stride, out in zip([2, 4, 8, 8], outs):
for stride, out in zip([1, 2, 4, 8], outs):
self.assertEqual(out.shape,
(1, 128 * stride, 56 // stride, 56 // stride))

Expand Down

0 comments on commit b5bb86a

Please sign in to comment.