-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add model script, training recipe and pretrained weight of cmt_s #680
Conversation
3c4214a
to
d9bdd27
Compare
mindcv/models/cmt.py
Outdated
to_2tuple = _ntuple(2) | ||
|
||
|
||
class DropPath(nn.Cell): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以用公共层的实现吗
configs/cmt/README.md
Outdated
|
||
### Deployment | ||
|
||
Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
链接不对,链接到新地址
mindcv/models/cmt.py
Outdated
nn.Conv2d(hidden_features, out_features, 1, 1, has_bias=True), | ||
nn.BatchNorm2d(out_features), | ||
]) | ||
self.drop = nn.Dropout(keep_prob=1-drop) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用公共层的dropout
mindcv/models/cmt.py
Outdated
|
||
|
||
def swish(x): | ||
return x * P.Sigmoid()(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数式接口ops.sigmoid
mindcv/models/cmt.py
Outdated
|
||
def construct(self, x, H, W): | ||
B, _, C = x.shape | ||
x = ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops.transpose
mindcv/models/cmt.py
Outdated
self.patch_embed_d = PatchEmbed( | ||
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) | ||
|
||
self.relative_pos_a = P.Zeros()( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops.zeros
b18be79
to
720e601
Compare
fixed |
mindcv/models/cmt.py
Outdated
|
||
class CMT(nn.Cell): | ||
def __init__( | ||
self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里缩进降低4个空格
mindcv/models/cmt.py
Outdated
self.patch_embed_d = PatchEmbed( | ||
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) | ||
|
||
self.relative_pos_a = ops.Zeros()( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops.zeros()
fixed |
momentum: 0.9 | ||
weight_decay: 0.05 | ||
loss_scale_type: 'dynamic' | ||
loss_scale: 16777216.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really large. How can you get this magic number?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your contribution to the MindCV repo.
Before submitting this PR, please make sure:
Motivation
The model script, training recipe and pertained weight of cmt_small is added.
Test Plan
Please refer to 'configs/cmt/README.md' for testing and reproducing.
Related Issues and PRs
(Is this PR part of a group of changes? Link the other relevant PRs and Issues here. Use https://help.github.com/en/articles/closing-issues-using-keywords for help on GitHub syntax)