[Fix] fix patch_embed and pos_embed mismatch error #685
Conversation
Codecov Report
@@ Coverage Diff @@
## master #685 +/- ##
=======================================
Coverage 85.78% 85.78%
=======================================
Files 105 105
Lines 5627 5627
Branches 915 916 +1
=======================================
Hits 4827 4827
Misses 621 621
Partials 179 179
mmseg/models/backbones/vit.py
Outdated
x, H, W = self.patch_embed(
    inputs), self.patch_embed.DH, self.patch_embed.DW
We may use `hw_shape` instead, to be consistent with Swin Transformer.
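The convention the reviewer is pointing at can be sketched as follows. This is an illustrative helper, not the actual mmseg code; the name `patch_grid_shape` is an assumption. The idea is to return the downsampled token grid as a single `hw_shape` tuple, as Swin Transformer does, instead of reading separate `DH`/`DW` attributes off the patch-embed module:

```python
# Hedged sketch (patch_grid_shape is a made-up name, not the mmseg API):
# compute the (H, W) token-grid shape produced by patch embedding and pass
# it around as one hw_shape tuple, mirroring the Swin Transformer style.
def patch_grid_shape(img_h, img_w, patch_size):
    """Return the (H, W) token-grid shape produced by patch embedding."""
    return img_h // patch_size, img_w // patch_size

hw_shape = patch_grid_shape(512, 512, 16)  # (32, 32)
```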
mmseg/models/backbones/vit.py
Outdated
@@ -317,14 +316,13 @@ def init_weights(self):
             constant_init(m.bias, 0)
             constant_init(m.weight, 1.0)

-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, downsampled_img_size, patched_img, pos_embed):
Suggested change:
-    def _pos_embeding(self, downsampled_img_size, patched_img, pos_embed):
+    def _pos_embeding(self, x, hw_shape, pos_embed):
assert out_shape in ['NLC',
                     'NCHW'], 'output shape must be "NLC" or "NCHW".'
if output_cls_token:
    assert with_cls_token is True, f'with_cls_token must be True if' \
It is better to add this description to the docstring.
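A minimal sketch of what the reviewer suggests: state the `with_cls_token`/`output_cls_token` constraint in the class docstring, not only in the runtime assert. The class below is a simplified stand-in, not the real `VisionTransformer`:

```python
# Hypothetical sketch: the runtime constraint from the snippet above,
# repeated in the Args section of the class docstring.
class VisionTransformerSketch:
    """Simplified ViT backbone (illustration only).

    Args:
        with_cls_token (bool): Whether to concatenate a class token to the
            patch tokens. Must be True when ``output_cls_token`` is True.
        output_cls_token (bool): Whether to also return the class token.
    """

    def __init__(self, with_cls_token=True, output_cls_token=False):
        if output_cls_token:
            assert with_cls_token is True, (
                'with_cls_token must be True if output_cls_token is True.')
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
```

This way the constraint is visible to readers of the API docs, and the assert remains as a guard for misconfigured instances.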
mmseg/models/backbones/vit.py
Outdated
pos_embed (torch.Tensor): pos_embed weights.
input_shpae (tuple): Tuple for (input_h, intput_w).
pos_shape (tuple): Tuple for (pos_h, pos_w).
patch_size (int): Patch size.
mode (str): Algorithm used for upsampling.
It would be better to describe the Args in more detail, and we don't need abbreviations in the description, e.g. `pos_embed` should be "position embedding".
mmseg/models/backbones/vit.py
Outdated
@@ -371,7 +370,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
         1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
     pos_embed_weight = F.interpolate(
         pos_embed_weight,
-        size=[input_h // patch_size, input_w // patch_size],
+        size=[input_h, input_w],
Suggested change:
-        size=[input_h, input_w],
+        size=input_shpae,
so `input_h, input_w = input_shpae` is redundant.
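Putting the reviewed pieces together, the resize step can be sketched as below. This is a hedged reconstruction of the function under discussion, not the exact mmseg implementation; it follows the suggestion of passing the target grid shape straight to `F.interpolate`:

```python
import torch
import torch.nn.functional as F

# Sketch of the reviewed change (not the verbatim mmseg code): interpolate
# the grid part of the position embedding to the target token grid, keeping
# the cls-token embedding untouched.
def resize_pos_embed(pos_embed, input_shape, pos_shape, mode='bicubic'):
    """Resize position-embedding weights.

    pos_embed: tensor of shape [1, 1 + pos_h * pos_w, C], cls token first.
    input_shape: target (h, w) token grid after patch embedding.
    pos_shape: (pos_h, pos_w) grid the weights were trained with.
    """
    pos_h, pos_w = pos_shape
    cls_token_weight = pos_embed[:, 0:1]   # keep the cls-token embedding
    pos_embed_weight = pos_embed[:, 1:]    # the spatial grid part
    pos_embed_weight = pos_embed_weight.reshape(
        1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
    pos_embed_weight = F.interpolate(
        pos_embed_weight, size=input_shape, mode=mode, align_corners=False)
    pos_embed_weight = pos_embed_weight.flatten(2).transpose(1, 2)
    return torch.cat((cls_token_weight, pos_embed_weight), dim=1)
```

Passing `size=input_shape` directly avoids unpacking into `input_h, input_w` first, which is the redundancy the reviewer points out.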
patched_img (torch.Tensor): The patched image, it should be
    shape of [B, L1, C].
hw_shape (tuple): The downsampled image resolution.
Maybe `output_shape` is better. I am not sure. @xvjiarui
mmseg/models/backbones/vit.py
Outdated
             self.interpolate_mode)
         return self.drop_after_pos(patched_img + pos_embed)

     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
         """Resize pos_embed weights.

         Resize pos_embed using bicubic interpolate method.
         Args:
             pos_embed (torch.Tensor): pos_embed weights.
             input_shpae (tuple): Tuple for (input_h, intput_w).
Maybe `output_shape` is better. I am not sure. @xvjiarui
…into fix_vit_pos_embed
If -> Whether
…mentation into fix_vit_pos_embed
Some configs still have
I have checked all vit configs,
* fix patch_embed and pos_embed mismatch error
* add docstring
* update unittest
* use downsampled image shape
* use tuple
* remove unused parameters and add doc
* fix init weights function
* revise docstring
* Update vit.py If -> Whether
* fix lint

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation

Fix patch_embed and pos_embed mismatch error and remove the `out_shape` param.

Modification

- In the `resize_pos_embed()` function, interpolation now matches the padded input image.
- The `out_shape` parameter is replaced with `output_cls_token`, so the user may choose whether to append the `cls_token` to the output feature map.

Checklist