
Adding ConvNeXt architecture in prototype #5197

Merged: 22 commits, Jan 20, 2022
Changes from 12 commits
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -1,4 +1,5 @@
from .alexnet import *
from .convnext import *
from .resnet import *
from .vgg import *
from .squeezenet import *
201 changes: 201 additions & 0 deletions torchvision/models/convnext.py
@@ -0,0 +1,201 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..ops.stochastic_depth import StochasticDepth
from ..utils import _log_api_usage_once


__all__ = [
    "ConvNeXt",
    "convnext_tiny",
]


model_urls: Dict[str, Optional[str]] = {}


class LayerNorm2d(nn.LayerNorm):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.channels_last = kwargs.pop("channels_last", False)
        super().__init__(*args, **kwargs)

    def forward(self, x):
        if not self.channels_last:
            x = x.permute(0, 2, 3, 1)
Contributor Author:

At TorchVision, we try to reuse standard PyTorch components as much as possible. I understand that in your original implementation you provide a custom implementation for channels-first. Could you talk about the performance degradation you experienced, and how large it was, that led you to reimplement it?

Reply:

The thing is, we don't have a standard PyTorch component for a channels_first LN. Here are some discussions with Ross Wightman (https://twitter.com/wightmanr/status/1481383509142818817?s=20). In Ross's timm implementation, he has a presumably better LN implementation for channels_first. See here: https://github.com/rwightman/pytorch-image-models/blob/b669f4a5881d17fe3875656ec40138c1ef50c4c9/timm/models/convnext.py#L109
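For context, a minimal sketch (not part of this PR) contrasting the two options under discussion: reusing F.layer_norm after permuting to channels-last, versus normalizing directly over the channel dimension of an NCHW tensor, similar in spirit to the timm implementation linked above. The class name here is hypothetical.

import torch
import torch.nn.functional as F
from torch import nn


class ChannelsFirstLayerNorm(nn.Module):
    # Illustrative only: LayerNorm over the channel dim of an NCHW tensor, without permutes.
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position over its C channels.
        mean = x.mean(dim=1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]


def permute_layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The "reuse standard kernels" alternative: NCHW -> NHWC, F.layer_norm over C, back to NCHW.
    x = x.permute(0, 2, 3, 1)
    x = F.layer_norm(x, (x.shape[-1],), weight, bias, eps)
    return x.permute(0, 3, 1, 2)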

Contributor Author:

Thanks for the references. I wasn't aware of the concurrent discussions. I'll have a look and measure on our side. Ideally I would like to reuse the existing kernels as much as possible, unless there is a big gap in performance that justifies a custom implementation.

@liuzhuang13 (Jan 17, 2022):

Hi @datumbox,
In this post I try to explain why we use linear layers instead of conv layers for the 1x1 convs in residual blocks, in case it is of any help: facebookresearch/ConvNeXt#18 (comment)

As for why we use the custom LN in the downsampling layers instead of permuting -> PyTorch LN -> permuting back, the reason is similar: we observe the former is slightly faster when used in the downsampling layers.

Thanks for your work on incorporating ConvNeXt!
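As an aside (not from the PR), the equivalence behind that choice can be checked directly: a 1x1 Conv2d and a Linear layer compute the same per-position projection, and the difference is only which memory layout avoids permutes. A small hypothetical check:

import torch
from torch import nn

# A 1x1 conv equals a per-position Linear layer with the same weights.
conv = nn.Conv2d(96, 384, kernel_size=1)
linear = nn.Linear(96, 384)
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 96, 56, 56)  # NCHW
out_conv = conv(x)  # stays channels-first
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # channels-last matmul
print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True, up to float error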

Contributor Author:

@liuzhuang13 Thanks for the reply. Is it fair to say that the approach we follow here is expected to be 0-5% slower than the optimum? I haven't had the chance to run benchmarks, but that's what I understand from your note.

FYI, the issue is that TorchVision has a common API across all models that accepts the norm_layer as a parameter, which will be tricky to support if I switch the 1x1 convs to linear. Whether or not we do this depends on the speed impact.

@liuzhuang13 (Jan 17, 2022):

I just tried switching to conv layers in the residual blocks and using permutation + LN in all cases. I found that, combined, they cause a 20-30% slowdown in inference at 224 resolution for ConvNeXt-T/S, compared to our released implementation. However, ConvNeXt-B at 224, or any model at 384 resolution, seems as fast as our released implementation. I only tried ConvNeXt T/S/B. This is on V100s, and I cannot say much about other platforms. It is indeed a bit strange to me.

Contributor Author:

Thanks a lot for looking into it @liuzhuang13. A 20-30% slowdown sounds very large, and we would probably want to make it faster. I'll run benchmarks later and share any findings with you.

        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if not self.channels_last:
            x = x.permute(0, 3, 1, 2)
        return x
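One way to probe the slowdown discussed in the thread above is a quick timing of the permute + F.layer_norm path used by LayerNorm2d. This is a rough sketch, not the PR's benchmark; it assumes a CUDA device, and results depend heavily on GPU, dtype, and tensor shape. A channels-first alternative (such as the hypothetical ChannelsFirstLayerNorm sketched earlier, or timm's implementation) can be timed the same way for comparison.

import torch
import torch.utils.benchmark as benchmark
from torch import nn

ln = nn.LayerNorm(96).cuda()
x = torch.randn(64, 96, 56, 56, device="cuda")

def permute_ln(t):
    # NCHW -> NHWC, LayerNorm over C, back to NCHW.
    return ln(t.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

timer = benchmark.Timer(stmt="permute_ln(x)", globals={"permute_ln": permute_ln, "x": x})
print(timer.timeit(100))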


class CNBlock(nn.Module):
    def __init__(self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]):
        super().__init__()
        self.block = nn.Sequential(
            ConvNormActivation(
                dim,
                dim,
                kernel_size=7,
                groups=dim,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,  # TODO: check
            ),
            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
            ConvNormActivation(
                4 * dim,
                dim,
                kernel_size=1,
                norm_layer=None,
                activation_layer=None,
            ),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result


class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
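For reference, the stochastic-depth schedule computed in the loop above ramps the per-block drop probability linearly from 0 to stochastic_depth_prob. With the convnext_tiny configuration below (3 + 3 + 9 + 3 = 18 blocks, stochastic_depth_prob=0.1), the values look like this (illustrative calculation only):

total_stage_blocks = 18
stochastic_depth_prob = 0.1
probs = [stochastic_depth_prob * i / (total_stage_blocks - 1.0) for i in range(total_stage_blocks)]
print([round(p, 3) for p in probs])
# [0.0, 0.006, 0.012, 0.018, ..., 0.088, 0.094, 0.1]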


def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    if pretrained:
        arch = "convnext_tiny"
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
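A quick sanity check of the builder above (a sketch, not part of the PR; note that pretrained=True would raise a ValueError in this revision because model_urls is still empty):

import torch
from torchvision.models import convnext_tiny

model = convnext_tiny()
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])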
5 changes: 3 additions & 2 deletions torchvision/ops/misc.py
@@ -131,7 +131,7 @@ def __init__(
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
         dilation: int = 1,
-        inplace: bool = True,
+        inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
         if padding is None:
@@ -153,7 +153,8 @@ def __init__(
         if norm_layer is not None:
             layers.append(norm_layer(out_channels))
         if activation_layer is not None:
-            layers.append(activation_layer(inplace=inplace))
+            params = {} if inplace is None else {"inplace": inplace}
+            layers.append(activation_layer(**params))
         super().__init__(*layers)
         _log_api_usage_once(self)
         self.out_channels = out_channels
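The Optional[bool] change above exists because some activation layers, for example nn.GELU used by the ConvNeXt block, do not accept an inplace argument; with inplace=None the keyword is simply omitted. A small sketch of the effect (not part of the PR):

from torch import nn

for activation_layer, inplace in [(nn.ReLU, True), (nn.GELU, None)]:
    params = {} if inplace is None else {"inplace": inplace}
    layer = activation_layer(**params)  # nn.GELU(inplace=True) would raise a TypeError
    print(layer)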
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
@@ -1,4 +1,5 @@
from .alexnet import *
from .convnext import *
from .densenet import *
from .efficientnet import *
from .googlenet import *
34 changes: 34 additions & 0 deletions torchvision/prototype/models/convnext.py
@@ -0,0 +1,34 @@
from typing import Any, Optional

from ...models.convnext import ConvNeXt, CNBlockConfig
from ._api import WeightsEnum
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]


class ConvNeXt_Tiny_Weights(WeightsEnum):
    pass


@handle_legacy_interface(weights=("pretrained", None))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
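A minimal sketch of how the prototype builder could be exercised. Since ConvNeXt_Tiny_Weights has no entries yet in this revision, only weights=None is usable; the decorator maps the legacy pretrained= flag onto the weights argument.

from torchvision.prototype.models import convnext_tiny

model = convnext_tiny(weights=None)
print(sum(p.numel() for p in model.parameters()))  # roughly 28M parameters for the Tiny variant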