Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc.). Checkpoint handling tweaks.
rwightman committed Jan 26, 2022
1 parent de5fa79 commit abc9ba2
Showing 61 changed files with 321 additions and 280 deletions.
2 changes: 1 addition & 1 deletion clean_checkpoint.py
@@ -49,7 +49,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
# load with the unmodified model using BatchNorm2d.
continue
name = k[7:] if k.startswith('module') else k
name = k[7:] if k.startswith('module.') else k
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))

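
Note: the tightened check above only strips the 'module.' prefix that nn.DataParallel / nn.DistributedDataParallel add when saving, and no longer mangles keys that merely begin with the word 'module'. A minimal illustration of the idea (strip_module_prefix is a hypothetical helper, not part of this commit):

    from collections import OrderedDict

    def strip_module_prefix(state_dict):
        # 'module.conv1.weight' (saved from a DataParallel wrapper) -> 'conv1.weight';
        # a key like 'module_list.0.weight' is now left untouched, whereas the old
        # startswith('module') check would have truncated it to 'list.0.weight'.
        out = OrderedDict()
        for k, v in state_dict.items():
            out[k[7:] if k.startswith('module.') else k] = v
        return out

    print(strip_module_prefix({'module.conv1.weight': 0, 'module_list.0.weight': 1}))
    # {'conv1.weight': 0, 'module_list.0.weight': 1}
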
10 changes: 5 additions & 5 deletions tests/test_models.py
@@ -11,8 +11,8 @@
has_fx_feature_extraction = False

import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
get_pretrained_cfg_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions

if hasattr(torch._C, '_jit_set_profiling_executor'):
@@ -54,9 +54,9 @@
def _get_input_size(model=None, model_name='', target=None):
if model is None:
assert model_name, "One of model or model_name must be provided"
input_size = get_model_default_value(model_name, 'input_size')
fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
min_input_size = get_model_default_value(model_name, 'min_input_size')
input_size = get_pretrained_cfg_value(model_name, 'input_size')
fixed_input_size = get_pretrained_cfg_value(model_name, 'fixed_input_size')
min_input_size = get_pretrained_cfg_value(model_name, 'min_input_size')
else:
default_cfg = model.default_cfg
input_size = default_cfg['input_size']
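
The renamed registry helpers keep the old query behaviour but read from pretrained_cfg. For example (illustrative; 'resnet50' is just an example model name and the returned values depend on the installed version):

    import timm

    print(timm.get_pretrained_cfg_value('resnet50', 'input_size'))       # e.g. (3, 224, 224)
    print(timm.has_pretrained_cfg_key('resnet50', 'fixed_input_size'))   # is the key present in the cfg?
    print(timm.is_model_pretrained('resnet50'))                          # unchanged by this commit
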
4 changes: 2 additions & 2 deletions timm/__init__.py
@@ -1,4 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
get_model_default_value, is_model_pretrained
is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
get_pretrained_cfg_value, is_model_pretrained
4 changes: 2 additions & 2 deletions timm/models/__init__.py
@@ -49,10 +49,10 @@
from .xception_aligned import *
from .xcit import *

from .factory import create_model, split_model_name, safe_model_name
from .factory import create_model, parse_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
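
get_pretrained_cfg is now exported alongside the per-key helpers, replacing direct default_cfgs lookups. A small usage sketch (assuming it returns a plain dict copy of the registered cfg at this stage of the transition; 'resnet50' is illustrative):

    from timm.models import get_pretrained_cfg

    cfg = get_pretrained_cfg('resnet50')
    print(cfg.get('url', ''), cfg.get('input_size'), cfg.get('num_classes'))
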
4 changes: 1 addition & 3 deletions timm/models/beit.py
@@ -339,14 +339,12 @@ def forward(self, x):
return x


def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs):
default_cfg = default_cfg or default_cfgs[variant]
def _create_beit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Beit models.')

model = build_model_with_cfg(
Beit, variant, pretrained,
default_cfg=default_cfg,
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
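
The builder simplification repeated across the model files below relies on build_model_with_cfg resolving the per-variant cfg itself. A rough sketch of the assumed resolution order (hypothetical helper; the real logic lives in timm/models/helpers.py and handles more cases, e.g. feature-extraction variants):

    from copy import deepcopy

    def _resolve_pretrained_cfg(variant, pretrained_cfg=None, default_cfgs=None):
        # An externally supplied cfg (HF Hub entry, file, user dict) is assumed to take
        # precedence; otherwise fall back to the cfg registered for this variant.
        if pretrained_cfg:
            return deepcopy(pretrained_cfg)
        return deepcopy((default_cfgs or {}).get(variant, {}))
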
1 change: 0 additions & 1 deletion timm/models/byoanet.py
@@ -327,7 +327,6 @@ def _cfg(url='', **kwargs):
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
1 change: 0 additions & 1 deletion timm/models/byobnet.py
@@ -1553,7 +1553,6 @@ def _init_weights(module, name='', zero_init_last=False):
def _create_byobnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
3 changes: 1 addition & 2 deletions timm/models/cait.py
@@ -14,7 +14,7 @@
from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model

@@ -318,7 +318,6 @@ def _create_cait(variant, pretrained=False, **kwargs):

model = build_model_with_cfg(
Cait, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
3 changes: 1 addition & 2 deletions timm/models/coat.py
@@ -16,7 +16,7 @@
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert
@@ -610,7 +610,6 @@ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):

model = build_model_with_cfg(
CoaT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
5 changes: 1 addition & 4 deletions timm/models/convit.py
@@ -318,10 +318,7 @@ def _create_convit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')

return build_model_with_cfg(
ConViT, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return build_model_with_cfg(ConViT, variant, pretrained, **kwargs)


@register_model
2 changes: 1 addition & 1 deletion timm/models/convmixer.py
@@ -80,7 +80,7 @@ def forward(self, x):


def _create_convmixer(variant, pretrained=False, **kwargs):
return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)


@register_model
1 change: 0 additions & 1 deletion timm/models/crossvit.py
@@ -413,7 +413,6 @@ def pretrained_filter_fn(state_dict):

return build_model_with_cfg(
CrossViT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=pretrained_filter_fn,
**kwargs)

1 change: 0 additions & 1 deletion timm/models/cspnet.py
@@ -413,7 +413,6 @@ def _create_cspnet(variant, pretrained=False, **kwargs):
cfg_variant = variant.split('_')[0]
return build_model_with_cfg(
CspNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
**kwargs)

1 change: 0 additions & 1 deletion timm/models/densenet.py
@@ -288,7 +288,6 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
kwargs['block_config'] = block_config
return build_model_with_cfg(
DenseNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
**kwargs)

1 change: 0 additions & 1 deletion timm/models/dla.py
@@ -341,7 +341,6 @@ def forward(self, x):
def _create_dla(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
DLA, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=False,
feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
**kwargs)
1 change: 0 additions & 1 deletion timm/models/dpn.py
@@ -264,7 +264,6 @@ def forward(self, x):
def _create_dpn(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
DPN, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_concat=True, flatten_sequential=True),
**kwargs)

7 changes: 3 additions & 4 deletions timm/models/efficientnet.py
@@ -48,7 +48,7 @@
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
from .registry import register_model

@@ -599,12 +599,11 @@ def _create_effnet(variant, pretrained=False, **kwargs):
model_cls = EfficientNetFeatures
model = build_model_with_cfg(
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
return model


@@ -1475,7 +1474,7 @@ def efficientnet_b0_g16_evos(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group 16 conv + EvoNorm"""
model = _gen_efficientnet(
'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs)
pretrained=pretrained, **kwargs) #norm_layer=partial(EvoNorm2dS0, group_size=16),
return model


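
With features_only=True the builder constructs EfficientNetFeatures, and the classifier-oriented keys are dropped from the attached cfg by the renamed pretrained_cfg_for_features. Usage is unchanged (model name illustrative, no weights downloaded here):

    import timm

    m = timm.create_model('efficientnet_b0', pretrained=False, features_only=True)
    print(m.feature_info.channels())   # per-stage channel counts, e.g. [16, 24, 40, 112, 320]
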
49 changes: 23 additions & 26 deletions timm/models/factory.py
@@ -1,30 +1,36 @@
from urllib.parse import urlsplit, urlunsplit
import os

from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint
from .layers import set_layer_config
from .hub import load_model_config_from_hf


def split_model_name(model_name):
model_split = model_name.split(':', 1)
if len(model_split) == 1:
return '', model_split[0]
def parse_model_name(model_name):
model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use
parsed = urlsplit(model_name)
assert parsed.scheme in ('', 'timm', 'hf-hub')
if parsed.scheme == 'hf-hub':
# FIXME may use fragment as revision, currently `@` in URI path
return parsed.scheme, parsed.path
else:
source_name, model_name = model_split
assert source_name in ('timm', 'hf_hub')
return source_name, model_name
model_name = os.path.split(parsed.path)[-1]
return 'timm', model_name


def safe_model_name(model_name, remove_source=True):
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
if remove_source:
model_name = split_model_name(model_name)[-1]
model_name = parse_model_name(model_name)[-1]
return make_safe(model_name)


def create_model(
model_name,
pretrained=False,
pretrained_cfg=None,
checkpoint_path='',
scriptable=None,
exportable=None,
@@ -45,33 +51,24 @@ def create_model(
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
"""
source_name, model_name = split_model_name(model_name)

# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
" Setting drop_path to %f." % drop_connect_rate)
kwargs['drop_path_rate'] = drop_connect_rate

# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if source_name == 'hf_hub':
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
# load model weights + default_cfg from Hugging Face hub.
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
model_source, model_name = parse_model_name(model_name)
if model_source == 'hf-hub':
# FIXME hf-hub source overrides any passed in pretrained_cfg, warn?
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name = load_model_config_from_hf(model_name)

if is_model(model_name):
create_fn = model_entrypoint(model_name)
else:
if not is_model(model_name):
raise RuntimeError('Unknown model (%s)' % model_name)

create_fn = model_entrypoint(model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, **kwargs)
model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)

if checkpoint_path:
load_checkpoint(model, checkpoint_path)
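
The model-source prefix is now parsed as a URI scheme, with 'hf_hub:' rewritten to 'hf-hub:' for backwards compatibility. A quick check of the new behaviour (repo name below is illustrative):

    from timm.models import parse_model_name

    print(parse_model_name('resnet50'))                  # ('timm', 'resnet50')
    print(parse_model_name('hf-hub:user/repo-name'))     # ('hf-hub', 'user/repo-name')
    print(parse_model_name('hf_hub:user/repo-name'))     # ('hf-hub', 'user/repo-name')

    # With an hf-hub source, create_model pulls pretrained_cfg + weights from the Hub entry
    # (requires network access and a repo with a compatible config):
    # model = timm.create_model('hf-hub:user/repo-name', pretrained=True)
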
50 changes: 43 additions & 7 deletions timm/models/fx_features.py
@@ -1,13 +1,15 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable
from typing import Callable, List, Dict, Union

import torch
from torch import nn

from .features import _get_feature_info

try:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
@@ -61,18 +63,52 @@ def register_notrace_function(func: Callable):
return func


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
)


class FeatureGraphNet(nn.Module):
""" A FX Graph based feature extractor that works with the model feature_info metadata
"""
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
return_nodes = {
info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(model, return_nodes)

def forward(self, x):
return list(self.graph_module(x).values())


class FeatureExtractNet(nn.Module):
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
NOTE:
* one can use feature_extractor directly if dictionary output is desired
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
metadata for builtin feature extraction mode
* feature_extractor can be used directly if dictionary output is acceptable
Args:
model: model to extract features from
return_nodes: node names to return features from (dict or list)
squeeze_out: if only one output, and output in list format, flatten to single tensor
"""
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
super().__init__()
self.squeeze_out = squeeze_out
self.graph_module = create_feature_extractor(model, return_nodes)

def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
out = list(self.graph_module(x).values())
if self.squeeze_out and len(out) == 1:
return out[0]
return out
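
The new create_feature_extractor wrapper bakes timm's leaf modules and autowrap functions into the tracer, and FeatureExtractNet exposes FX-based extraction without the feature_info plumbing. A usage sketch (requires torch 1.10+ / torchvision 0.11+; the node names are assumptions that hold for timm's ResNet and can be verified with torchvision's get_graph_node_names):

    import torch
    import timm
    from timm.models.fx_features import FeatureExtractNet

    model = timm.create_model('resnet50', pretrained=False)
    extractor = FeatureExtractNet(model, return_nodes=['layer2', 'layer4'], squeeze_out=False)
    feats = extractor(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])   # e.g. [(1, 512, 28, 28), (1, 2048, 7, 7)]
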
1 change: 0 additions & 1 deletion timm/models/ghostnet.py
@@ -250,7 +250,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
)
return build_model_with_cfg(
GhostNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**model_kwargs)

5 changes: 1 addition & 4 deletions timm/models/gluon_resnet.py
@@ -58,10 +58,7 @@ def _cfg(url='', **kwargs):


def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


@register_model
