Commit 5f81d4d

Move DeiT to its own file, vit getting crowded. Working towards fixing huggingface#1029, make the pooling interface for transformers and mlp models closer to convnets. Still working through some details...

rwightman committed Jan 27, 2022
1 parent 95cfc9b commit 5f81d4d
Showing 19 changed files with 370 additions and 291 deletions.
11 changes: 6 additions & 5 deletions tests/test_models.py
@@ -205,22 +205,23 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
outputs = model.forward_features(input_tensor)
if isinstance(outputs, (tuple, list)):
outputs = outputs[0]
assert outputs.shape[1] == model.num_features
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features

# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
if isinstance(outputs, (tuple, list)):
outputs = outputs[0]
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features

model = create_model(model_name, pretrained=False, num_classes=0).eval()
outputs = model.forward(input_tensor)
if isinstance(outputs, (tuple, list)):
outputs = outputs[0]
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features

# check classifier name matches default_cfg
if cfg.get('num_classes', None):
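The feat_dim checks above reflect the new pooling interface: transformer-style models now return an unpooled token sequence (B, N, C) from forward_features(), while convnets return a (B, C, H, W) map, so the feature dimension is -1 for 3-d outputs and 1 otherwise. A hedged sketch of that logic outside the test harness (model name is illustrative):

```python
import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

outputs = model.forward_features(x)            # unpooled: (1, tokens, C) here
feat_dim = -1 if outputs.ndim == 3 else 1      # -1 for NLC tokens, 1 for NCHW maps
assert outputs.shape[feat_dim] == model.num_features
```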
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -8,6 +8,7 @@
from .convnext import *
from .crossvit import *
from .cspnet import *
from .deit import *
from .densenet import *
from .dla import *
from .dpn import *
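With DeiT split out into timm/models/deit.py and re-exported here, model creation through the registry is unchanged. A hedged sketch:

```python
# The DeiT entrypoints now live in timm/models/deit.py; lookup and creation
# through the registry work as before.
import timm

model = timm.create_model('deit_base_patch16_224', pretrained=False)
assert 'deit_base_patch16_224' in timm.list_models('deit_*')
```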
37 changes: 19 additions & 18 deletions timm/models/beit.py
@@ -232,13 +232,15 @@ class Beit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
use_mean_pooling=True, init_scale=0.001):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
head_init_scale=0.001):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models

self.patch_embed = PatchEmbed(
@@ -247,10 +249,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=drop_rate)

if use_shared_rel_pos_bias:
@@ -266,8 +265,9 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
for i in range(depth)])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

self.apply(self._init_weights)
@@ -278,8 +278,8 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)

def fix_init_weight(self):
def rescale(param, layer_id):
@@ -327,14 +327,15 @@ def forward_features(self, x):
x = blk(x, rel_pos_bias=rel_pos_bias)

x = self.norm(x)
if self.fc_norm is not None:
t = x[:, 1:, :]
return self.fc_norm(t.mean(1))
else:
return x[:, 0]
return x

def forward(self, x):
x = self.forward_features(x)
if self.fc_norm is not None:
x = x[:, 1:].mean(dim=1)
x = self.fc_norm(x)
else:
x = x[:, 0]
x = self.head(x)
return x
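BEiT now follows the shared head wiring: use_mean_pooling is replaced by global_pool='avg' (which enables fc_norm), forward_features() returns the full class+patch token sequence, and pooling plus fc_norm or class-token selection happen in forward() before the head. A hedged usage sketch, assuming this commit:

```python
# forward_features() returns unpooled tokens; forward() applies avg pooling
# with fc_norm when global_pool == 'avg', otherwise it takes the class token.
import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

tokens = model.forward_features(x)   # (1, 1 + 14*14, 768), no pooling applied
logits = model(x)                    # (1, 1000), pooled inside forward()
```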

44 changes: 20 additions & 24 deletions timm/models/cait.py
@@ -213,11 +213,11 @@ def __init__(
act_layer=nn.GELU,
attn_block=TalkingHeadAttn,
mlp_block=Mlp,
init_scale=1e-4,
init_values=1e-4,
attn_block_token_only=ClassAttn,
mlp_block_token_only=Mlp,
depth_token_only=2,
mlp_ratio_clstk=4.0
mlp_ratio_token_only=4.0
):
super().__init__()

@@ -234,19 +234,19 @@ def __init__(
self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([
self.blocks = nn.Sequential(*[
block_layers(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values)
for i in range(depth)])

self.blocks_token_only = nn.ModuleList([
block_layers_token(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias,
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only, init_values=init_scale)
mlp_block=mlp_block_token_only, init_values=init_values)
for i in range(depth_token_only)])

self.norm = norm_layer(embed_dim)
@@ -281,25 +281,21 @@ def reset_classifier(self, num_classes, global_pool=''):
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)

cls_tokens = self.cls_token.expand(B, -1, -1)

x = x + self.pos_embed
x = self.pos_drop(x)
x = self.blocks(x)

for i, blk in enumerate(self.blocks):
x = blk(x)

cls_tokens = self.cls_token.expand(B, -1, -1)
for i, blk in enumerate(self.blocks_token_only):
cls_tokens = blk(x, cls_tokens)

x = torch.cat((cls_tokens, x), dim=1)

x = self.norm(x)
return x[:, 0]
return x

def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
return x
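CaiT gets the same split: forward_features() returns the normalized class+patch token sequence and forward() selects the class token before the head; the patch blocks are now an nn.Sequential, so they run as a single self.blocks(x) call. A hedged sketch, assuming this commit:

```python
import torch
import timm

model = timm.create_model('cait_xxs24_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

tokens = model.forward_features(x)   # (1, 1 + 196, 192), class token first
logits = model(x)                    # (1, 1000)
```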

@@ -326,69 +322,69 @@ def _create_cait(variant, pretrained=False, **kwargs):

@register_model
def cait_xxs24_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
return model


@register_model
def cait_xxs24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_xxs36_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
return model


@register_model
def cait_xxs36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_xs24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs)
model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_s24_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
return model


@register_model
def cait_s24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_s36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs)
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs)
model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_m36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs)
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs)
model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
return model


@register_model
def cait_m48_448(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs)
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs)
model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
return model
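The per-variant layer-scale value is now passed as init_values (renamed from init_scale), and mlp_ratio_clstk becomes mlp_ratio_token_only, matching the naming used by the other transformer models. A hedged sketch of what the cait_xxs24_224 entrypoint builds; the Cait class name and its remaining defaults are assumed, not shown in this diff:

```python
from timm.models.cait import Cait  # class name assumed from this file

model = Cait(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
```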
24 changes: 12 additions & 12 deletions timm/models/coat.py
@@ -447,6 +447,7 @@ def __init__(
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
else:
# CoaT-Lite series: Use feature of last scale for classification.
self.aggregate = None
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

# Initialize weights.
@@ -542,8 +543,7 @@ def forward_features(self, x0):
else:
# Return features for classification.
x4 = self.norm4(x4)
x4_cls = x4[:, 0]
return x4_cls
return x4

# Parallel blocks.
for blk in self.parallel_blocks:
@@ -574,20 +574,20 @@ def forward_features(self, x0):
x2 = self.norm2(x2)
x3 = self.norm3(x3)
x4 = self.norm4(x4)
x2_cls = x2[:, :1] # [B, 1, C]
x3_cls = x3[:, :1]
x4_cls = x4[:, :1]
merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C]
merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]
return merged_cls

def forward(self, x):
if self.return_interm_layers:
return [x2, x3, x4]

def forward(self, x) -> torch.Tensor:
if not torch.jit.is_scripting() and self.return_interm_layers:
# Return intermediate features (for down-stream tasks).
return self.forward_features(x)
else:
# Return features for classification.
x = self.forward_features(x)
x_feat = self.forward_features(x)
if isinstance(x_feat, (tuple, list)):
x = torch.cat([xl[:, :1] for xl in x_feat], dim=1) # [B, 3, C]
x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
else:
x = x_feat[:, 0]
x = self.head(x)
return x
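CoaT now returns raw features from forward_features() (a single normalized token sequence for CoaT-Lite, or the per-scale sequences for the parallel nets) and moves class-token selection and aggregation into forward(), behind a torch.jit.is_scripting() guard for the intermediate-feature path. A standalone sketch of the parallel-branch aggregation step; the aggregate module is assumed to be a 1x1 Conv1d over the three scales:

```python
import torch
import torch.nn as nn

B, C, num_classes = 2, 216, 1000
x2, x3, x4 = (torch.randn(B, 50, C) for _ in range(3))       # per-scale token sequences
aggregate = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
head = nn.Linear(C, num_classes)

cls = torch.cat([xl[:, :1] for xl in (x2, x3, x4)], dim=1)   # [B, 3, C]
pooled = aggregate(cls).squeeze(dim=1)                        # [B, C]
logits = head(pooled)                                         # [B, num_classes]
```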

3 changes: 2 additions & 1 deletion timm/models/convit.py
@@ -308,10 +308,11 @@ def forward_features(self, x):
x = blk(x)

x = self.norm(x)
return x[:, 0]
return x

def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
return x
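ConViT follows the same convention: forward_features() returns the normalized token sequence and forward() takes the class token. A hedged sketch, assuming this commit; with the classifier removed the head is an Identity, so forward() yields the class-token features directly:

```python
import torch
import timm

model = timm.create_model('convit_tiny', pretrained=False, num_classes=0).eval()
feats = model(torch.randn(1, 3, 224, 224))   # (1, model.num_features)
```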

3 changes: 1 addition & 2 deletions timm/models/convmixer.py
@@ -69,13 +69,12 @@ def reset_classifier(self, num_classes, global_pool=''):
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.pooling(x)
return x

def forward(self, x):
x = self.forward_features(x)
x = self.pooling(x)
x = self.head(x)

return x
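ConvMixer moves global pooling out of forward_features() and into forward(), so forward_features() now returns the unpooled NCHW feature map like the other convnets. A hedged sketch, assuming this commit (the spatial size follows from the patch-size-7 stem of this variant):

```python
import torch
import timm

model = timm.create_model('convmixer_768_32', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

fmap = model.forward_features(x)   # (1, 768, 32, 32), unpooled feature map
logits = model(x)                  # (1, 1000), pooling + head in forward()
```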


1 change: 0 additions & 1 deletion timm/models/convnext.py
@@ -319,7 +319,6 @@ def checkpoint_filter_fn(state_dict, model):
def _create_convnext(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
ConvNeXt, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs)
6 changes: 3 additions & 3 deletions timm/models/crossvit.py
@@ -368,7 +368,7 @@ def reset_classifier(self, num_classes, global_pool=''):
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])

def forward_features(self, x):
def forward_features(self, x) -> List[torch.Tensor]:
B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
@@ -389,11 +389,11 @@ def forward_features(self, x):

# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
return [xo[:, 0] for xo in xs]
return xs

def forward(self, x):
xs = self.forward_features(x)
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
ce_logits = [head(xs[i][:, 0]) for i, head in enumerate(self.head)]
if not isinstance(self.head[0], nn.Identity):
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
return ce_logits
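CrossViT's forward_features() now returns one unpooled token sequence per branch (typed List[torch.Tensor] for scripting), and forward() takes each branch's class token before applying and averaging the per-branch heads. A hedged sketch, assuming this commit:

```python
import torch
import timm

model = timm.create_model('crossvit_tiny_240', pretrained=False).eval()
x = torch.randn(1, 3, 240, 240)

xs = model.forward_features(x)     # list of 2 tensors, each (1, 1 + N_i, C_i)
logits = model(x)                  # (1, 1000), mean of the two branch heads
```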