Skip to content

Commit

Permalink
Adding multiweight support for mobilenetv2 prototype (#4784)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdsgomes authored Oct 28, 2021
1 parent 79b350e commit 082f37e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .vgg import *
from .efficientnet import *
from .mobilenetv3 import *
from .mobilenetv2 import *
from .mnasnet import *
from . import detection
from . import quantization
Expand Down
46 changes: 46 additions & 0 deletions torchvision/prototype/models/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import warnings
from functools import partial
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ...models.mobilenetv2 import MobileNetV2
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES


__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]


_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}


class MobileNetV2Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
"acc@1": 71.878,
"acc@5": 90.286,
},
)


def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV2Weights.verify(weights)

if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])

model = MobileNetV2(**kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))

return model

0 comments on commit 082f37e

Please sign in to comment.