Adding ConvNeXt architecture in prototype #5197
Merged

Commits (22):
6682748  Adding CNBlock and skeleton architecture (datumbox)
6c49ef8  Completed implementation. (datumbox)
a3034c4  Adding model in prototypes. (datumbox)
e57b64f  Add test and minor refactor for JIT. (datumbox)
8cddcac  Fix mypy. (datumbox)
6aedbcc  Merge branch 'main' into models/convnext (datumbox)
0bef112  Fixing naming conventions. (datumbox)
cf69832  Fixing tests. (datumbox)
eb4c825  Fix stochastic depth percentages. (datumbox)
52960cf  Adding stochastic depth to tiny variant. (datumbox)
8ddc17c  Minor refactoring and adding comments. (datumbox)
6dd11b7  Merge branch 'main' into models/convnext (datumbox)
ce05e24  Adding weights. (datumbox)
c4ffc84  Update default weights. (datumbox)
7af0e20  Fix transforms issue (datumbox)
442a7bf  Merge branch 'main' into models/convnext (datumbox)
1ee5b0f  Move convnext to prototype. (datumbox)
be2972e  linter fix (datumbox)
f47a590  fix docs (datumbox)
9e6fda1  Addressing code review comments. (datumbox)
daf07e0  Merge branch 'main' into models/convnext (datumbox)
2edbd8d  Merge branch 'main' into models/convnext (datumbox)
Binary file not shown.
@@ -0,0 +1,201 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..ops.stochastic_depth import StochasticDepth
from ..utils import _log_api_usage_once


__all__ = [
    "ConvNeXt",
    "convnext_tiny",
]


model_urls: Dict[str, Optional[str]] = {}


class LayerNorm2d(nn.LayerNorm):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.channels_last = kwargs.pop("channels_last", False)
        super().__init__(*args, **kwargs)

    def forward(self, x):
        if not self.channels_last:
            x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if not self.channels_last:
            x = x.permute(0, 3, 1, 2)
        return x


class CNBlock(nn.Module):
    def __init__(self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]):
        super().__init__()
        self.block = nn.Sequential(
            ConvNormActivation(
                dim,
                dim,
                kernel_size=7,
                groups=dim,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,  # TODO: check
            ),
            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
            ConvNormActivation(
                4 * dim,
                dim,
                kernel_size=1,
                norm_layer=None,
                activation_layer=None,
            ),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result


class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    if pretrained:
        arch = "convnext_tiny"
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
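Note that stochastic_depth_prob is spread linearly over the blocks: sd_prob is 0.0 for the first block and reaches the full stochastic_depth_prob at the last one. For a quick sanity check of the builder above, here is a minimal usage sketch; the import path torchvision.models.convnext is an assumption, since the file path is not visible in this rendering of the diff.

import torch
from torchvision.models.convnext import convnext_tiny  # import path assumed, not shown in the diff

# Build the Tiny variant without pretrained weights (model_urls is still empty,
# so pretrained=True would raise ValueError at this point in the PR).
model = convnext_tiny(pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

print(out.shape)  # torch.Size([1, 1000]) with the default num_classes=1000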
@@ -0,0 +1,34 @@
from typing import Any, Optional

from ...models.convnext import ConvNeXt, CNBlockConfig
from ._api import WeightsEnum
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]


class ConvNeXt_Tiny_Weights(WeightsEnum):
    pass


@handle_legacy_interface(weights=("pretrained", None))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
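A minimal usage sketch of the prototype builder above; the import path torchvision.prototype.models is an assumption based on the relative imports in the file, and since ConvNeXt_Tiny_Weights has no members yet, only random initialization is possible at this stage.

import torch
from torchvision.prototype.models import convnext_tiny  # path assumed from the relative imports above

# No weight entries are registered yet, so we can only build an untrained model.
model = convnext_tiny(weights=None, num_classes=10)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 10])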
At TorchVision, we try to reuse standard PyTorch components as much as possible. I understand that in your original implementation you provide a custom implementation for channels-first LayerNorm. Could you talk about the performance degradation you experienced, and how large it was, that led you to reimplement it?
The thing is that we don't have a standard PyTorch component for channels_first LN. Here are some discussions with Ross Wightman (https://twitter.com/wightmanr/status/1481383509142818817?s=20). In Ross's timm implementation, he has a presumably better LN implementation for channels_first. See here: https://github.com/rwightman/pytorch-image-models/blob/b669f4a5881d17fe3875656ec40138c1ef50c4c9/timm/models/convnext.py#L109
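For readers following the thread, below is a rough sketch of what a channels_first LayerNorm looks like when written directly against (N, C, H, W) tensors, normalizing over the channel dimension without the permute round trip. It is only an illustration in the spirit of the linked timm code, not a copy of it, and the class name is made up.

import torch
from torch import nn, Tensor


class ChannelsFirstLayerNorm(nn.Module):
    # Normalizes over the channel dimension of an (N, C, H, W) tensor directly,
    # avoiding the permute -> F.layer_norm -> permute round trip.
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]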
Thanks for the references. I wasn't aware of the concurrent discussions. I'll have a look and measure on our side. Ideally I would like to reuse the existing kernels as much as possible, unless there is a big enough gap in performance to justify a custom implementation.
Hi @datumbox,
In this post I try to explain why we use linear layers instead of conv layers for 1x1 convs in residual blocks, in case it is any help: facebookresearch/ConvNeXt#18 (comment)
As for why we use the custom LN in the downsampling layers instead of permuting -> PyTorch LN -> permuting back, the reason is similar: we observe that the former is slightly faster when used in downsampling layers.
Thanks for your work on incorporating ConvNeXt!
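To make the linear-vs-1x1-conv point concrete: the two layers compute the same channel mixing and differ only in memory layout and kernel choice, which is where the speed difference comes from. A small sketch, with shapes picked arbitrarily for illustration:

import torch
from torch import nn

x = torch.randn(2, 96, 56, 56)  # (N, C, H, W)

conv = nn.Conv2d(96, 384, kernel_size=1)
linear = nn.Linear(96, 384)

# Tie the parameters so both paths compute the same function.
with torch.no_grad():
    linear.weight.copy_(conv.weight.flatten(1))
    linear.bias.copy_(conv.bias)

y_conv = conv(x)                                            # channels-first path
y_lin = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # channels-last path

print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True, up to numerical noise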
@liuzhuang13 Thanks for the reply. Is it fair to say that the approach that we follow here is expected to be 0-5% slower than the optimum? I haven't had the chance to run benchmarks but that's what I understand from your note.
FYI, the issue is that TorchVision has a common API across all models that accepts norm_layer as a parameter, which is going to be tricky to support if I switch the 1x1 convs to linear. Whether or not we do this depends on the speed impact.
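For context, the norm_layer hook mentioned here is the usual TorchVision pattern: any callable mapping a channel count to a module can be passed into the ConvNeXt constructor from the first file of this diff. A sketch with GroupNorm as an arbitrary stand-in (the import path is assumed, as before):

from functools import partial

from torch import nn
from torchvision.models.convnext import ConvNeXt, CNBlockConfig  # path assumed

block_setting = [
    CNBlockConfig(96, 192, 3),
    CNBlockConfig(192, 384, 3),
    CNBlockConfig(384, 768, 9),
    CNBlockConfig(768, None, 3),
]
# norm_layer is called as norm_layer(num_channels) throughout the model, so any
# callable with that signature works; GroupNorm(1, C) is just an example here.
model = ConvNeXt(block_setting, norm_layer=partial(nn.GroupNorm, 1))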
I just tried switching to conv layers in residual blocks, and using permutation + LN in all cases. I found that, combined, they cause a 20-30% slowdown in inference @ 224 resolution for ConvNeXt-T/S, compared to our released impl. However, for ConvNeXt-B at 224, or any model at 384 resolution, it seems as fast as our released impl. I only tried ConvNeXt T/S/B. This is on V100s, and I cannot say much about other platforms though. It is indeed a bit strange to me.
Thanks a lot for looking into it @liuzhuang13. A 20-30% slowdown sounds very large and we would probably want to make it faster. I'll run benchmarks later and share any findings with you.
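A rough sketch of the kind of micro-benchmark being discussed, comparing the permute + F.layer_norm path used in this PR against a LayerNorm written directly for channels-first tensors; the shapes, device handling, and use of torch.utils.benchmark are illustrative assumptions, not the setup either party ran.

import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(32, 96, 56, 56, device=device)
weight = torch.ones(96, device=device)
bias = torch.zeros(96, device=device)


def permute_layer_norm(x: torch.Tensor) -> torch.Tensor:
    # The approach in this PR: permute to channels-last, normalize, permute back.
    y = x.permute(0, 2, 3, 1)
    y = F.layer_norm(y, (96,), weight, bias, 1e-6)
    return y.permute(0, 3, 1, 2)


def channels_first_layer_norm(x: torch.Tensor) -> torch.Tensor:
    # Normalize over the channel dimension directly, without permutes.
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, keepdim=True, unbiased=False)
    y = (x - mean) / torch.sqrt(var + 1e-6)
    return weight[:, None, None] * y + bias[:, None, None]


for name, fn in [("permute + F.layer_norm", permute_layer_norm),
                 ("channels-first LN", channels_first_layer_norm)]:
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(name, timer.timeit(100))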