From 931032519acf16e9833ad3881fc9a45b69d258d8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Feb 2022 19:26:53 +0000 Subject: [PATCH] Refactoring weight info. --- torchvision/models/efficientnet.py | 3 +- torchvision/prototype/models/efficientnet.py | 43 +++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 82e4174cd9b..c56fac844ba 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -42,6 +42,8 @@ "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + # Temporary TF weights + "efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth", } @@ -176,7 +178,6 @@ def __init__( cnf: FusedMBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], - **kwargs: Any, ) -> None: super().__init__() diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 95e7d0987f3..94f871d337c 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -57,17 +57,30 @@ def _efficientnet( return model -_COMMON_META_V1 = { +_COMMON_META = { "task": "image_classification", - "architecture": "EfficientNet", - "publication_year": 2019, - "min_size": (1, 1), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BICUBIC, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", } +_COMMON_META_V1 = { + **_COMMON_META, + "architecture": "EfficientNet", + "publication_year": 2019, + "min_size": (1, 1), +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "architecture": "EfficientNetV2", + "publication_year": 2021, + "min_size": (33, 33), +} + + class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", @@ -202,7 +215,25 @@ class EfficientNet_B7_Weights(WeightsEnum): class EfficientNet_V2_S_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth", + transforms=partial( + ImageNetEval, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "size": (384, 384), + "acc@1": 83.152, + "acc@5": 96.400, + }, + ) + DEFAULT = IMAGENET1K_V1 class EfficientNet_V2_M_Weights(WeightsEnum): @@ -317,7 +348,7 @@ def efficientnet_b7( ) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) def efficientnet_v2_s( *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: