Skip to content

Commit

Permalink
Refactoring weight info.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Feb 21, 2022
1 parent aa82cf1 commit 9310325
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
3 changes: 2 additions & 1 deletion torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
# Temporary TF weights
"efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth",
}


Expand Down Expand Up @@ -176,7 +178,6 @@ def __init__(
cnf: FusedMBConvConfig,
stochastic_depth_prob: float,
norm_layer: Callable[..., nn.Module],
**kwargs: Any,
) -> None:
super().__init__()

Expand Down
43 changes: 37 additions & 6 deletions torchvision/prototype/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,30 @@ def _efficientnet(
return model


_COMMON_META_V1 = {
_COMMON_META = {
"task": "image_classification",
"architecture": "EfficientNet",
"publication_year": 2019,
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
}


_COMMON_META_V1 = {
**_COMMON_META,
"architecture": "EfficientNet",
"publication_year": 2019,
"min_size": (1, 1),
}


_COMMON_META_V2 = {
**_COMMON_META,
"architecture": "EfficientNetV2",
"publication_year": 2021,
"min_size": (33, 33),
}


class EfficientNet_B0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
Expand Down Expand Up @@ -202,7 +215,25 @@ class EfficientNet_B7_Weights(WeightsEnum):


class EfficientNet_V2_S_Weights(WeightsEnum):
pass
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth",
transforms=partial(
ImageNetEval,
crop_size=384,
resize_size=384,
interpolation=InterpolationMode.BICUBIC,
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
meta={
**_COMMON_META_V2,
"num_params": 21458488,
"size": (384, 384),
"acc@1": 83.152,
"acc@5": 96.400,
},
)
DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_M_Weights(WeightsEnum):
Expand Down Expand Up @@ -317,7 +348,7 @@ def efficientnet_b7(
)


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
Expand Down

0 comments on commit 9310325

Please sign in to comment.