Adding more ConvNeXt variants + Speed optimizations #5253

Merged · 19 commits · Feb 1, 2022
3 changes: 3 additions & 0 deletions docs/source/models.rst
@@ -249,6 +249,9 @@ vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
convnext_tiny (prototype) 82.520 96.146
convnext_small (prototype) 83.616 96.650
convnext_base (prototype) 84.062 96.870
convnext_large (prototype) 84.414 96.976
================================ ============= =============


5 changes: 3 additions & 2 deletions references/classification/README.md
@@ -201,11 +201,12 @@ and `--batch_size 64`.
### ConvNeXt
```
torchrun --nproc_per_node=8 train.py\
--model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
```
Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step; see the corresponding `Weights` entry for the exact value.
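
For reference, a small illustrative snippet (not part of this PR) that spells out the per-variant `--val-resize-size` values, mirroring the `resize_size` recorded in the `Weights` entries added in `torchvision/prototype/models/convnext.py` below:

```
# Illustrative only: evaluation resize sizes picked post-training for each variant.
VAL_RESIZE_SIZE = {
    "convnext_tiny": 236,
    "convnext_small": 230,
    "convnext_base": 232,
    "convnext_large": 232,
}

for model_name, resize_size in VAL_RESIZE_SIZE.items():
    print(f"--model {model_name} ... --val-resize-size {resize_size}")
```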

Note that the above command corresponds to training on a single node with 8 GPUs.
For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
3 binary files not shown.
204 changes: 158 additions & 46 deletions torchvision/prototype/models/convnext.py
@@ -15,47 +15,56 @@
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]
__all__ = [
"ConvNeXt",
"ConvNeXt_Tiny_Weights",
"ConvNeXt_Small_Weights",
"ConvNeXt_Base_Weights",
"ConvNeXt_Large_Weights",
"convnext_tiny",
"convnext_small",
"convnext_base",
"convnext_large",
]


class LayerNorm2d(nn.LayerNorm):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.channels_last = kwargs.pop("channels_last", False)
super().__init__(*args, **kwargs)

def forward(self, x: Tensor) -> Tensor:
# TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
if not self.channels_last:
x = x.permute(0, 2, 3, 1)
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
if not self.channels_last:
x = x.permute(0, 3, 1, 2)
x = x.permute(0, 3, 1, 2)
return x
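
As a sanity check on the simplified `LayerNorm2d` above, here is a small standalone sketch (not part of this PR) showing that the permute → `F.layer_norm` → permute pattern normalizes each spatial position over the channel dimension of an NCHW tensor:

```
import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm2d(nn.LayerNorm):
    # Permute NCHW -> NHWC, normalize over the (last) channel dim, permute back.
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)

ln = LayerNorm2d(96, eps=1e-6)
x = torch.randn(2, 96, 56, 56)
y = ln(x)

# Reference computation directly in channels-first layout.
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
ref = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight[:, None, None] + ln.bias[:, None, None]
print(torch.allclose(y, ref, atol=1e-5))  # True
print(y.shape)                            # torch.Size([2, 96, 56, 56])
```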


class Permute(nn.Module):
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims

def forward(self, x):
return torch.permute(x, self.dims)
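
For completeness, a tiny standalone sketch (not from this PR) of how the `Permute` helper behaves; it simply wraps `torch.permute` so layout changes can sit inside an `nn.Sequential`:

```
import torch
from torch import nn

class Permute(nn.Module):  # same shape-shuffling helper as above
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return torch.permute(x, self.dims)

x = torch.randn(2, 96, 7, 7)                 # NCHW
nhwc = Permute([0, 2, 3, 1])(x)              # -> NHWC
print(nhwc.shape)                            # torch.Size([2, 7, 7, 96])
print(Permute([0, 3, 1, 2])(nhwc).shape)     # back to torch.Size([2, 96, 7, 7])
```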


class CNBlock(nn.Module):
def __init__(
self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
self,
dim,
layer_scale: float,
stochastic_depth_prob: float,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)

self.block = nn.Sequential(
ConvNormActivation(
dim,
dim,
kernel_size=7,
groups=dim,
norm_layer=norm_layer,
activation_layer=None,
bias=True,
),
ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
ConvNormActivation(
4 * dim,
dim,
kernel_size=1,
norm_layer=None,
activation_layer=None,
),
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
Permute([0, 2, 3, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 3, 1, 2]),
)
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
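
The restructured `CNBlock` above replaces the 1×1 `ConvNormActivation` layers with `nn.Linear` applied to a permuted NHWC tensor, which is where the speed optimization in this file comes from. A minimal equivalence sketch (illustrative only, with made-up shapes) showing that a 1×1 convolution and a `Linear` over the channel dimension compute the same map:

```
import torch
from torch import nn

dim = 96
conv1x1 = nn.Conv2d(dim, 4 * dim, kernel_size=1, bias=True)
linear = nn.Linear(dim, 4 * dim, bias=True)

# Copy the 1x1 conv weights into the linear layer so both compute the same map.
with torch.no_grad():
    linear.weight.copy_(conv1x1.weight.squeeze(-1).squeeze(-1))
    linear.bias.copy_(conv1x1.bias)

x = torch.randn(2, dim, 14, 14)                                 # NCHW
out_conv = conv1x1(x)
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # NHWC round trip
print(torch.allclose(out_conv, out_linear, atol=1e-5))          # True
```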
@@ -138,7 +147,7 @@ def __init__(
for _ in range(cnf.num_layers):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
stage_block_id += 1
layers.append(nn.Sequential(*stage))
if cnf.out_channels is not None:
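
For context on the `sd_prob` formula above: it ramps the per-block stochastic depth probability linearly from 0 for the first block to `stochastic_depth_prob` for the last. A small illustration using the `convnext_tiny` depths (3, 3, 9, 3) and its default `stochastic_depth_prob=0.1`:

```
depths = [3, 3, 9, 3]                # num_layers per stage for convnext_tiny
stochastic_depth_prob = 0.1          # default used by the convnext_tiny builder
total_stage_blocks = sum(depths)     # 18

probs = [stochastic_depth_prob * i / (total_stage_blocks - 1.0) for i in range(total_stage_blocks)]
print(round(probs[0], 4), round(probs[8], 4), round(probs[-1], 4))  # 0.0 0.0471 0.1
```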
@@ -177,30 +186,95 @@ def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)


def _convnext(
block_setting: List[CNBlockConfig],
stochastic_depth_prob: float,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> ConvNeXt:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model


_COMMON_META = {
"task": "image_classification",
"architecture": "ConvNeXt",
"publication_year": 2022,
"size": (224, 224),
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
}


class ConvNeXt_Tiny_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
meta={
"task": "image_classification",
"architecture": "ConvNeXt",
"publication_year": 2022,
**_COMMON_META,
"num_params": 28589128,
"size": (224, 224),
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
"acc@1": 82.520,
"acc@5": 96.146,
},
)
DEFAULT = IMAGENET1K_V1


class ConvNeXt_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
meta={
**_COMMON_META,
"num_params": 50223688,
"acc@1": 83.616,
"acc@5": 96.650,
},
)
DEFAULT = IMAGENET1K_V1


class ConvNeXt_Base_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 88591464,
"acc@1": 84.062,
"acc@5": 96.870,
},
)
DEFAULT = IMAGENET1K_V1


class ConvNeXt_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 197767336,
"acc@1": 84.414,
"acc@5": 96.976,
},
)
DEFAULT = IMAGENET1K_V1
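
A usage sketch for the new weight enums (illustrative; this assumes the prototype package re-exports the names in `__all__`, and the prototype `Weights` API may still change):

```
import torch
from torchvision.prototype.models import convnext_base, ConvNeXt_Base_Weights

weights = ConvNeXt_Base_Weights.IMAGENET1K_V1
model = convnext_base(weights=weights).eval()     # downloads the checkpoint
preprocess = weights.transforms()                 # ImageNetEval(crop_size=224, resize_size=232)

print(weights.meta["num_params"], weights.meta["acc@1"])  # 88591464 84.062
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)                                  # torch.Size([1, 1000])
```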


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt model architecture from the
r"""ConvNeXt Tiny model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

Args:
@@ -209,19 +283,57 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
"""
weights = ConvNeXt_Tiny_Weights.verify(weights)

if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

block_setting = [
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 9),
CNBlockConfig(768, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
weights = ConvNeXt_Small_Weights.verify(weights)

block_setting = [
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 27),
CNBlockConfig(768, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
weights = ConvNeXt_Base_Weights.verify(weights)

block_setting = [
CNBlockConfig(128, 256, 3),
CNBlockConfig(256, 512, 3),
CNBlockConfig(512, 1024, 27),
CNBlockConfig(1024, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
weights = ConvNeXt_Large_Weights.verify(weights)

block_setting = [
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 3),
CNBlockConfig(768, 1536, 27),
CNBlockConfig(1536, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
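
Finally, an illustrative consistency check (no checkpoint download, and assuming the import path below) that the new block settings produce parameter counts matching the `num_params` recorded in each weights entry:

```
from torchvision.prototype.models.convnext import (
    convnext_small, convnext_base, convnext_large,
    ConvNeXt_Small_Weights, ConvNeXt_Base_Weights, ConvNeXt_Large_Weights,
)

variants = [
    ("convnext_small", convnext_small, ConvNeXt_Small_Weights.IMAGENET1K_V1),
    ("convnext_base", convnext_base, ConvNeXt_Base_Weights.IMAGENET1K_V1),
    ("convnext_large", convnext_large, ConvNeXt_Large_Weights.IMAGENET1K_V1),
]
for name, builder, weights in variants:
    model = builder(weights=None)                          # random init, nothing downloaded
    num_params = sum(p.numel() for p in model.parameters())
    print(name, num_params, weights.meta["num_params"])    # the two counts should match
```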