diff --git a/tests/test_models.py b/tests/test_models.py index ca5cac05fc..bb98d43e24 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -170,11 +170,12 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier name matches default_cfg - classifier = cfg['classifier'] - if not isinstance(classifier, (tuple, list)): - classifier = classifier, - for c in classifier: - assert c + ".weight" in state_dict.keys(), f'{c} not in model params' + if cfg.get('num_classes', None): + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' # check first conv(s) names match default_cfg first_conv = cfg['first_conv'] @@ -222,11 +223,12 @@ def test_model_default_cfgs_non_std(model_name, batch_size): assert outputs.shape[1] == model.num_features # check classifier name matches default_cfg - classifier = cfg['classifier'] - if not isinstance(classifier, (tuple, list)): - classifier = classifier, - for c in classifier: - assert c + ".weight" in state_dict.keys(), f'{c} not in model params' + if cfg.get('num_classes', None): + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' # check first conv(s) names match default_cfg first_conv = cfg['first_conv'] diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 16ce64d0bf..880fcc637f 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -221,8 +221,8 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte if num_classes != default_cfg['num_classes']: for classifier_name in classifiers: # completely discard fully connected if model num_classes doesn't match pretrained weights - del 
state_dict[classifier_name + '.weight'] - del state_dict[classifier_name + '.bias'] + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) strict = False elif label_offset > 0: for classifier_name in classifiers: diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5c2346ce26..65acacab09 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -140,11 +140,25 @@ def _cfg(url='', **kwargs): num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) - 'vit_base_patch32_sam_224': _cfg( + 'vit_base_patch32_224_sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), - 'vit_base_patch16_sam_224': _cfg( + 'vit_base_patch16_224_sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + # deit models (FB weights) 'deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', @@ -699,26 +713,6 @@ def vit_large_patch16_384(pretrained=False, 
**kwargs): return model -@register_model -def vit_base_patch16_sam_224(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 - """ - # NOTE original SAM weights release worked with representation_size=768 - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) - model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch32_sam_224(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 - """ - # NOTE original SAM weights release worked with representation_size=768 - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) - model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). @@ -851,6 +845,62 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch16_224_sam(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 + """ + # NOTE original SAM weights release worked with representation_size=768 + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_sam(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. 
Paper: https://arxiv.org/abs/2106.01548 + """ + # NOTE original SAM weights release worked with representation_size=768 + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_dino(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch8_224_dino(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_dino(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_dino(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model def deit_tiny_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny model @ 224x224 from 
paper (https://arxiv.org/abs/2012.12877).