Skip to content

Commit

Permalink
Add Medium and Large weights
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Mar 2, 2022
1 parent bf41dfb commit 9057045
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 8 deletions.
10 changes: 6 additions & 4 deletions references/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ torchrun --nproc_per_node=8 train.py \
--train-crop-size $TRAIN_SIZE --model-ema --val-crop-size $EVAL_SIZE --val-resize-size $EVAL_SIZE \
--ra-sampler --ra-reps 4
```
Here `$MODEL` is one of `efficientnet_v2_s`, `efficientnet_v2_m` and `efficientnet_v2_l`.
Note that the Small variant had a `$TRAIN_SIZE` of `300` and a `$EVAL_SIZE` of `384`, while the other variants `384` and `480` respectively.
Here `$MODEL` is one of `efficientnet_v2_s` and `efficientnet_v2_m`.
Note that the Small variant had a `$TRAIN_SIZE` of `300` and a `$EVAL_SIZE` of `384`, while the Medium `384` and `480` respectively.

Note that the above command corresponds to training on a single node with 8 GPUs.
For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
and `--batch_size 16`.
For generating the pre-trained weights, we trained with 4 nodes, each with 8 GPUs (for a total of 32 GPUs),
and `--batch_size 32`.

The weights of the Large variant are ported from the original paper rather than trained from scratch. See the `EfficientNet_V2_L_Weights` entry for their exact preprocessing transforms.


### RegNet
Expand Down
3 changes: 3 additions & 0 deletions torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
# Weights trained with TorchVision
"efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
"efficientnet_v2_m": "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
# Weights ported from TF
"efficientnet_v2_l": "https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
}


Expand Down
42 changes: 38 additions & 4 deletions torchvision/prototype/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,45 @@ class EfficientNet_V2_S_Weights(WeightsEnum):


class EfficientNet_V2_M_Weights(WeightsEnum):
    """Pre-trained weights for the EfficientNetV2-Medium architecture.

    The single entry, ``IMAGENET1K_V1``, was trained with TorchVision on
    ImageNet-1K. Evaluation uses a 480x480 center crop with bilinear
    interpolation, matching the training recipe described in
    ``references/classification/README.md``.
    """

    IMAGENET1K_V1 = Weights(
        # NOTE: checksum dc08266a is registered as the *_m* file in the legacy
        # `model_urls` dict; the filename must say `efficientnet_v2_m`, not
        # `efficientnet_v2_s` (copy-paste error in the original).
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        transforms=partial(
            ImageNetEval,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 54139356,
            "size": (480, 480),
            "acc@1": 85.119,
            "acc@5": 97.151,
        },
    )
    # Alias used when callers request the default weights for the model.
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_L_Weights(WeightsEnum):
    """Pre-trained weights for the EfficientNetV2-Large architecture.

    The single entry, ``IMAGENET1K_V1``, is ported from the original
    TensorFlow checkpoint rather than trained from scratch with TorchVision.
    Because of the port, preprocessing differs from the other V2 variants:
    bicubic interpolation and a (0.5, 0.5, 0.5) mean/std normalization are
    required to reproduce the reported accuracy.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        transforms=partial(
            ImageNetEval,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            # TF-ported weights were trained with [-1, 1] normalization,
            # hence the non-ImageNet mean/std.
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 118515272,
            "size": (480, 480),
            "acc@1": 85.808,
            "acc@5": 97.788,
        },
    )
    # Alias used when callers request the default weights for the model.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
Expand Down Expand Up @@ -365,7 +399,7 @@ def efficientnet_v2_s(
)


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
*, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
Expand All @@ -383,7 +417,7 @@ def efficientnet_v2_m(
)


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
*, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
Expand Down

0 comments on commit 9057045

Please sign in to comment.