Prep for PyPI release. Tweak NFNet, ResNetV2, RexNet feature extraction, use pre-act feature… #475


Merged: 3 commits, Mar 8, 2021
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -17,8 +17,8 @@ jobs:
matrix:
os: [ubuntu-latest, macOS-latest]
python: ['3.8']
- torch: ['1.7.0']
- torchvision: ['0.8.1']
+ torch: ['1.8.0']
+ torchvision: ['0.9.0']
runs-on: ${{ matrix.os }}

steps:
4 changes: 4 additions & 0 deletions README.md
@@ -2,6 +2,10 @@

## What's New

+ ### March 7, 2021
+ * First 0.4.x PyPI release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
+ * Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.

### Feb 18, 2021
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
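As a quick illustration of the pre-activation change noted above, the sketch below builds one of the affected models as a feature extractor. It assumes timm's existing `features_only` API and the `nfnet_l0c` variant configured later in this PR; `pretrained=False` is used so no weight download is needed:

```python
import torch
import timm

# Pre-act nets (NFNets, ResNetV2) now report features taken *before* the
# activation, so the extracted maps are unbounded pre-act tensors.
model = timm.create_model('nfnet_l0c', features_only=True, pretrained=False)

print(model.feature_info.channels())   # channels of each returned map
print(model.feature_info.reduction())  # stride of each map w.r.t. the input

x = torch.randn(1, 3, 224, 224)
for feat in model(x):
    print(tuple(feat.shape))
```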
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -21,7 +21,7 @@
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
- '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
+ '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
else:
EXCLUDE_FILTERS = NON_STD_FILTERS

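The exclusion entries above are shell-style wildcard patterns matched against model names. A standalone illustration of the matching logic, assuming `fnmatch` as used by timm's test helpers:

```python
import fnmatch

EXCLUDE_FILTERS = ['*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*']
models = ['nfnet_f0', 'nfnet_f3', 'dm_nfnet_f4', 'nfnet_l0c']

# A model is tested only if it matches none of the exclusion patterns.
tested = [m for m in models
          if not any(fnmatch.fnmatch(m, pat) for pat in EXCLUDE_FILTERS)]
print(tested)  # ['nfnet_f0', 'nfnet_l0c']
```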
6 changes: 4 additions & 2 deletions timm/models/cspnet.py
@@ -264,9 +264,11 @@ def forward(self, x):
if self.conv_down is not None:
x = self.conv_down(x)
x = self.conv_exp(x)
- xs, xb = x.chunk(2, dim=1)
+ split = x.shape[1] // 2
+ xs, xb = x[:, :split], x[:, split:]
xb = self.blocks(xb)
- out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
+ xb = self.conv_transition_b(xb).contiguous()
+ out = self.conv_transition(torch.cat([xs, xb], dim=1))
return out


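The cross-stage split itself is functionally unchanged; only its expression differs, using explicit slices plus `.contiguous()` on the transition input instead of `torch.chunk`. A standalone equivalence check for the even-channel case (not part of the PR):

```python
import torch

x = torch.randn(2, 64, 8, 8)

# Old form: chunk into two equal halves along the channel dim (returns views).
xs_old, xb_old = x.chunk(2, dim=1)

# New form, mirroring the updated forward() above: explicit slicing.
split = x.shape[1] // 2
xs_new, xb_new = x[:, :split], x[:, split:]

assert torch.equal(xs_old, xs_new)
assert torch.equal(xb_old, xb_new)
```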
2 changes: 1 addition & 1 deletion timm/models/layers/inplace_abn.py
@@ -10,7 +10,7 @@
def inplace_abn(x, weight, bias, running_mean, running_var,
training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
raise ImportError(
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")

def inplace_abn_sync(**kwargs):
inplace_abn(**kwargs)
26 changes: 11 additions & 15 deletions timm/models/nfnet.py
@@ -101,11 +101,12 @@ def _dcfg(url='', **kwargs):
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),

nfnet_l0a=_dcfg(
- url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+ url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
nfnet_l0b=_dcfg(
- url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+ url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
nfnet_l0c=_dcfg(
- url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0c-ad1045c2.pth',
+ pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),

nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@@ -376,9 +377,9 @@ def forward(self, x):
return out


- def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
+ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True):
stem_stride = 2
- stem_feature = dict(num_chs=out_chs, reduction=2, module='')
+ stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
stem = OrderedDict()
assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
if 'deep' in stem_type:
@@ -388,14 +389,14 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
strides = (2, 1, 1, 2)
stem_stride = 4
- stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4')
+ stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
else:
if 'tiered' in stem_type:
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
else:
stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
strides = (2, 1, 1)
- stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
+ stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
last_idx = len(stem_chs) - 1
for i, (c, s) in enumerate(zip(stem_chs, strides)):
stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
@@ -477,7 +478,7 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)

- self.feature_info = [stem_feat] if stem_stride == 4 else []
+ self.feature_info = [stem_feat]
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
prev_chs = stem_chs
net_stride = stem_stride
@@ -486,8 +487,6 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
stages = []
for stage_idx, stage_depth in enumerate(cfg.depths):
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
- if stride == 2:
-     self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
@@ -522,18 +521,19 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
first_dilation = dilation
prev_chs = out_chs
+ self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)

if cfg.num_features:
# The paper NFRegNet models have an EfficientNet-like final head convolution.
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+ self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.final_act = act_layer(inplace=cfg.num_features > 0)
+ self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]

self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

@@ -572,10 +572,6 @@ def forward(self, x):
def _create_normfreenet(variant, pretrained=False, **kwargs):
model_cfg = model_cfgs[variant]
feature_cfg = dict(flatten_sequential=True)
- feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
- if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
-     feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2 feat for stride 4, 1 layer maxpool stems

return build_model_with_cfg(
NormFreeNet, variant, pretrained,
default_cfg=default_cfgs[variant],
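With `feature_info` now built uniformly (a stem entry, one entry per stage, and `final_act` last) and the hook-based path removed, the metadata can be read straight off the model. A small sketch assuming the `nf_regnet_b0` variant defined in this file:

```python
import timm

# On the base model, feature_info is the raw list of dicts built in __init__.
model = timm.create_model('nf_regnet_b0', pretrained=False)
for f in model.feature_info:
    print(f['module'], f['num_chs'], f['reduction'])

# features_only now works through the plain flatten_sequential wrapper,
# including arbitrary out_indices, since no activation hooks are required.
feat_model = timm.create_model(
    'nf_regnet_b0', features_only=True, out_indices=(1, 2, 3))
print(feat_model.feature_info.module_name())
```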
15 changes: 3 additions & 12 deletions timm/models/resnetv2.py
@@ -323,8 +323,8 @@ def __init__(self, layers, channels=(256, 512, 1024, 2048),
self.feature_info = []
stem_chs = make_div(stem_chs * wf)
self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
- # NOTE no, reduction 2 feature if preact
- self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
+ stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
+ self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

prev_chs = stem_chs
curr_stride = 4
@@ -343,10 +343,7 @@ def __init__(self, layers, channels=(256, 512, 1024, 2048),
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
prev_chs = out_chs
curr_stride *= stride
- feat_name = f'stages.{stage_idx}'
- if preact:
-     feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
- self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
self.stages.add_module(str(stage_idx), stage)

self.num_features = prev_chs
@@ -414,13 +411,7 @@ def load_pretrained(self, checkpoint_path, prefix='resnet/'):


def _create_resnetv2(variant, pretrained=False, **kwargs):
- # FIXME feature map extraction is not setup properly for pre-activation mode right now
- preact = kwargs.get('preact', True)
feature_cfg = dict(flatten_sequential=True)
- if preact:
-     feature_cfg['feature_cls'] = 'hook'
-     feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for preact

return build_model_with_cfg(
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
feature_cfg=feature_cfg, **kwargs)
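Since pre-activation ResNetV2 no longer needs hook-based extraction, the standard sequential feature wrapper applies and all five stride levels (2 through 32) are reported, stem included. A minimal forward-pass sketch using the existing `resnetv2_50x1_bitm` variant (shapes are printed at runtime rather than asserted):

```python
import torch
import timm

model = timm.create_model('resnetv2_50x1_bitm', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))

for ch, red, feat in zip(model.feature_info.channels(),
                         model.feature_info.reduction(), feats):
    print(ch, red, tuple(feat.shape))
```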
30 changes: 12 additions & 18 deletions timm/models/rexnet.py
@@ -71,22 +71,23 @@ def forward(self, x):


class LinearBottleneck(nn.Module):
- def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
+ def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
+              act_layer='swish', dw_act_layer='relu6', drop_path=None):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_chs <= out_chs
self.in_channels = in_chs
self.out_channels = out_chs

if exp_ratio != 1.:
dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
- self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer="swish")
+ self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
else:
dw_chs = in_chs
self.conv_exp = None

self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
- self.act_dw = nn.ReLU6()
+ self.act_dw = create_act_layer(dw_act_layer)

self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
self.drop_path = drop_path
@@ -131,8 +132,7 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se


def _build_blocks(
- block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
- feat_exp = feature_location == 'expansion'
+ block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
feat_chs = [prev_chs]
feature_info = []
curr_stride = 2
@@ -141,41 +141,37 @@
for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
if stride > 1:
fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
- if block_idx > 0 and feat_exp:
-     fname += '.act_dw'
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
curr_stride *= stride
block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule
drop_path = DropPath(block_dpr) if block_dpr > 0. else None
features.append(LinearBottleneck(
in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
- ch_div=ch_div, drop_path=drop_path))
+ ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
prev_chs = chs
- feat_chs += [features[-1].feat_channels(feat_exp)]
+ feat_chs += [features[-1].feat_channels()]
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
- feature_info += [dict(
-     num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride,
-     module=f'features.{len(features) - int(not feat_exp)}')]
- features.append(ConvBnAct(prev_chs, pen_chs, act_layer="swish"))
+ feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
+ features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer))
return features, feature_info


class ReXNetV1(nn.Module):
def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
- ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
+ ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
super(ReXNetV1, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes

assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
- self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')
+ self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)

block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
features, self.feature_info = _build_blocks(
- block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
+ block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
self.num_features = features[-1].out_channels
self.features = nn.Sequential(*features)

@@ -202,8 +198,6 @@ def forward(self, x):

def _create_rexnet(variant, pretrained, **kwargs):
feature_cfg = dict(flatten_sequential=True)
- if kwargs.get('feature_location', '') == 'expansion':
-     feature_cfg['feature_cls'] = 'hook'
return build_model_with_cfg(
ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)

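The activations are no longer hard-coded: `'swish'` for the expansion/stem convs and ReLU6 on the depthwise path become the `act_layer`/`dw_act_layer` defaults, threaded through `_build_blocks` into each `LinearBottleneck`. A hedged sketch of overriding them, relying on `create_model` passing extra kwargs through to `ReXNetV1.__init__` and on string names resolving via timm's activation factory (as the file itself does with 'swish'/'relu6'):

```python
import timm

# Defaults: act_layer='swish', dw_act_layer='relu6'.
rexnet = timm.create_model('rexnet_100', pretrained=False)

# Hypothetical override: plain ReLU everywhere, e.g. for quantization trials.
rexnet_relu = timm.create_model(
    'rexnet_100', pretrained=False, act_layer='relu', dw_act_layer='relu')
```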
2 changes: 1 addition & 1 deletion timm/version.py
@@ -1 +1 @@
- __version__ = '0.4.4'
+ __version__ = '0.4.5'