
feat: add model script, training recipe and pretrained weight of cmt_s #680

Merged · 1 commit · Jul 10, 2023

Conversation

wcrzlh (Collaborator) commented Jun 12, 2023

Thank you for your contribution to the MindCV repo.
Before submitting this PR, please make sure:

Motivation

The model script, training recipe, and pretrained weights of cmt_small are added.

Test Plan

Please refer to 'configs/cmt/README.md' for testing and reproducing.

Related Issues and PRs


@wcrzlh wcrzlh force-pushed the cmt branch 3 times, most recently from 3c4214a to d9bdd27 on June 15, 2023 at 07:08
to_2tuple = _ntuple(2)


class DropPath(nn.Cell):
Collaborator:
Can this use the implementation from the common layers?


### Deployment

Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
Collaborator:
The link is wrong; point it at the new address.

nn.Conv2d(hidden_features, out_features, 1, 1, has_bias=True),
nn.BatchNorm2d(out_features),
])
self.drop = nn.Dropout(keep_prob=1-drop)
Collaborator:
Use the common-layer dropout.



def swish(x):
return x * P.Sigmoid()(x)
Collaborator:
Use the functional interface ops.sigmoid.
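For reference, swish is simply x * sigmoid(x). The sketch below illustrates the activation with NumPy standing in for the suggested functional `ops.sigmoid` call; the function names here are only illustrative, not the PR's final code.

```python
import numpy as np

# Swish activation, x * sigmoid(x). In the MindSpore code the reviewer's
# suggestion would read:  return x * ops.sigmoid(x)
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish(x):
    return x * sigmoid(x)

print(round(float(swish(np.float64(0.0))), 6))  # 0.0
print(round(float(swish(np.float64(1.0))), 6))  # 0.731059
```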


def construct(self, x, H, W):
B, _, C = x.shape
x = ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W)
Collaborator:
Use ops.transpose.
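The construct() snippet above turns a (B, N, C) token sequence back into a (B, C, H, W) feature map, assuming N == H * W. A NumPy sketch of the same transpose-then-reshape step (NumPy stands in for the suggested functional `ops.transpose`):

```python
import numpy as np

# Tokens: batch B, sequence length N = H * W, channels C.
B, H, W, C = 2, 4, 4, 8
x = np.arange(B * H * W * C, dtype=np.float32).reshape(B, H * W, C)

# (B, N, C) -> (B, C, N) -> (B, C, H, W), mirroring
# ops.transpose(x, (0, 2, 1)).reshape(B, C, H, W) in the model code.
y = np.transpose(x, (0, 2, 1)).reshape(B, C, H, W)

print(y.shape)  # (2, 8, 4, 4)
```

Each spatial position (h, w) in y holds the channel vector of token h * W + w, which is why the transpose must come before the reshape.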

self.patch_embed_d = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

self.relative_pos_a = P.Zeros()(
Collaborator:
Use ops.zeros.

@vigo999 vigo999 self-requested a review June 26, 2023 14:49
@wcrzlh wcrzlh force-pushed the cmt branch 2 times, most recently from b18be79 to 720e601 on June 29, 2023 at 09:32
wcrzlh (Collaborator, Author) commented Jun 29, 2023

fixed


class CMT(nn.Cell):
def __init__(
self,
Collaborator:
Reduce the indentation here by 4 spaces.

self.patch_embed_d = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

self.relative_pos_a = ops.Zeros()(
Collaborator:
Use the functional ops.zeros().

wcrzlh (Collaborator, Author) commented Jun 30, 2023

fixed

momentum: 0.9
weight_decay: 0.05
loss_scale_type: 'dynamic'
loss_scale: 16777216.0
Collaborator:
This is really large. How did you arrive at this magic number?

Collaborator (Author):
I used the initial loss scale (2**24) from mindspore.amp.DynamicLossScaleManager.
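The value in the recipe is exactly that initial scale; a quick arithmetic check:

```python
# Sanity check of the loss_scale value in the training recipe:
# the author states it is the initial dynamic loss scale, 2**24,
# used by mindspore.amp.DynamicLossScaleManager.
loss_scale = 2 ** 24
print(loss_scale)  # 16777216
assert float(loss_scale) == 16777216.0  # matches loss_scale: 16777216.0
```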

@wcrzlh wcrzlh requested a review from SamitHuang July 6, 2023 08:50
@geniuspatrick geniuspatrick merged commit c391b57 into mindspore-lab:main Jul 10, 2023
5 checks passed