Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shufflenetv2 1.5 and 2.0 weights #5906

Merged
merged 16 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ Densenet-201 76.896 93.370
Densenet-161 77.138 93.560
Inception v3 77.294 93.450
GoogleNet 69.778 89.530
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x0.5 60.552 81.746
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x1.5 72.784 91.058
ShuffleNet V2 x2.0 76.200 92.888
MobileNet V2 71.878 90.286
MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
Expand Down
30 changes: 28 additions & 2 deletions torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,37 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):


class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
pass
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth",
transforms=partial(ImageClassification, crop_size=224),
YosuaMichael marked this conversation as resolved.
Show resolved Hide resolved
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/5906",
"num_params": 3503624,
"metrics": {
"acc@1": 72.784,
"acc@5": 91.058,
},
},
)
DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/5906",
"num_params": 7393996,
"metrics": {
"acc@1": 76.200,
"acc@5": 92.888,
},
},
)
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
Expand Down