forked from huggingface/pytorch-image-models
Commit: Add MobileViT models (w/ ByobNet base). Close huggingface#1038.
Showing 2 changed files with 249 additions and 0 deletions.

@@ -0,0 +1,248 @@
""" MobileViT | ||
Paper: | ||
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178 | ||
MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below) | ||
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source) | ||
Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022, Ross Wightman | ||
""" | ||
# | ||
# For licensing see accompanying LICENSE file. | ||
# Copyright (C) 2020 Apple Inc. All Rights Reserved. | ||
# | ||
import math | ||
from typing import Union, Callable, Dict, Tuple, Optional | ||
|
||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups | ||
from .layers import to_2tuple, make_divisible | ||
from .vision_transformer import Block as TransformerBlock | ||
from .helpers import build_model_with_cfg | ||
from .registry import register_model | ||
|
||
__all__ = [] | ||
|
||
|
||
def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': (0, 0, 0), 'std': (1, 1, 1),
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'fixed_input_size': False, 'min_input_size': (3, 256, 256),
        **kwargs
    }
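
# Example (illustrative): _cfg(url='...', crop_pct=1.0) returns the defaults
# above with crop_pct overridden via **kwargs. Note mean=(0, 0, 0) and
# std=(1, 1, 1): these weights expect inputs scaled to [0, 1], with no
# ImageNet channel normalization.

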
default_cfgs = {
    # MobileViT checkpoints, adapted from https://github.com/apple/ml-cvnets
    'mobilevit_xxs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'),
    'mobilevit_xs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'),
    'mobilevit_s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
}


def _inverted_residual_block(d, c, s, br=4.0):
    # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise)
    return ByoBlockCfg(
        type='bottle', d=d, c=c, s=s, gs=1, br=br,
        block_kwargs=dict(bottle_in=True, linear_out=True))


def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, br=4.0):
    # inverted residual + mobilevit blocks as per MobileViT network
    return (
        _inverted_residual_block(d=d, c=c, s=s, br=br),
        ByoBlockCfg(
            type='mobilevit', d=1, c=c, s=1,
            block_kwargs=dict(
                transformer_dim=transformer_dim,
                transformer_depth=transformer_depth,
                patch_size=patch_size)
        )
    )
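
# Illustrative note (not in the commit): each _mobilevit_block(...) call yields a
# pair of ByoBlockCfg entries, e.g. _mobilevit_block(d=1, c=96, s=2, ...) expands
# to a strided 'bottle' (inverted residual) cfg followed by a single 'mobilevit'
# cfg at the same width, with the transformer settings passed via block_kwargs.

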
model_cfgs = dict(
    mobilevit_xxs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=16, s=1, br=2.0),
            _inverted_residual_block(d=3, c=24, s=2, br=2.0),
            _mobilevit_block(d=1, c=48, s=2, transformer_dim=64, transformer_depth=2, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=80, transformer_depth=4, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=96, transformer_depth=3, patch_size=2, br=2.0),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=320,
    ),

    mobilevit_xs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=48, s=2),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=96, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=120, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=384,
    ),

    mobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=640,
    ),
)


class MobileViTBlock(nn.Module):
    """ MobileViT block
        Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            stride: int = 1,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            num_heads: int = 4,
            attn_drop: float = 0.,
            drop: float = 0.,
            no_fusion: bool = False,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = nn.LayerNorm,
            downsample: str = ''
    ):
        super(MobileViTBlock, self).__init__()

        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=stride, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        self.transformer = nn.Sequential(*[
            TransformerBlock(
                transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
                attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
                act_layer=layers.act, norm_layer=transformer_norm_layer)
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1)

        if no_fusion:
            self.conv_fusion = None
        else:
            self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape
        new_h, new_w = int(math.ceil(H / patch_h) * patch_h), int(math.ceil(W / patch_w) * patch_w)
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        interpolate = False
        if new_h != H or new_w != W:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patch -> feature map)
        # [BP, N, C] --> [B, P, N, C] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
        # [B * C * n_h, n_w, p_h, p_w] --> [B * C * n_h, p_h, n_w, p_w] --> [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
        if interpolate:
            x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

        x = self.conv_proj(x)
        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
        return x


register_block('mobilevit', MobileViTBlock)
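
# Standalone round-trip sketch (not part of the commit): checks that the
# unfold/fold reshape sequence in MobileViTBlock.forward above is lossless.
# Between unfold and fold the transformer operates on B * P sequences of N
# tokens, so pixels sharing the same intra-patch offset attend across patches.
def _check_unfold_fold_roundtrip():
    B, C, H, W = 2, 8, 16, 16
    patch_h = patch_w = 2
    num_patch_h, num_patch_w = H // patch_h, W // patch_w  # n_h, n_w
    num_patches = num_patch_h * num_patch_w  # N
    patch_area = patch_h * patch_w  # P

    x = torch.randn(B, C, H, W)

    # unfold: [B, C, H, W] --> [B * P, N, C]
    u = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
    u = u.reshape(B, C, num_patches, patch_area).transpose(1, 3).reshape(B * patch_area, num_patches, -1)

    # fold: [B * P, N, C] --> [B, C, H, W]
    f = u.contiguous().view(B, patch_area, num_patches, -1)
    f = f.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
    f = f.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

    assert torch.equal(x, f)  # unfold followed by fold is the identity

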
def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


@register_model
def mobilevit_xxs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)


@register_model
def mobilevit_xs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xs', pretrained=pretrained, **kwargs)


@register_model
def mobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)
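

# Minimal usage sketch (assumes this module is importable as part of a timm
# install so the @register_model entrypoints above resolve; pretrained=True
# would fetch weights from the default_cfgs URLs):
if __name__ == '__main__':
    model = mobilevit_s(pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)  # torch.Size([1, 1000])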