Skip to content

Commit

Permalink
[Feature] timm backbones wrapper (open-mmlab#427)
Browse files Browse the repository at this point in the history
* Add wrapper to use backbones from timm

* Add tests

* Remove timm from optional deps and modify GitHub workflow.

Co-authored-by: mzr1996 <mzr1996@163.com>
  • Loading branch information
amirassov and mzr1996 authored Sep 6, 2021
1 parent 1b92cac commit 6ec098d
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 10 deletions.
55 changes: 46 additions & 9 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,57 @@ jobs:
- name: Linting
run: pre-commit run --all-files

build:
build_without_timm:
runs-on: ubuntu-latest
env:
UBUNTU_VERSION: ubuntu1804
strategy:
matrix:
python-version: [3.7]
torch: [1.3.0, 1.5.0, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
python-version: [3.6]
torch: [1.3.0, 1.8.0, 1.9.0]
include:
- torch: 1.3.0
torchvision: 0.4.2
- torch: 1.8.0
torchvision: 0.9.0
- torch: 1.9.0
torchvision: 0.10.0

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install Pillow
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision < 0.5}}
- name: Install PyTorch
run: pip install --use-deprecated=legacy-resolver torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMCV
run: |
pip install --use-deprecated=legacy-resolver mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
python -c 'import mmcv; print(mmcv.__version__)'
- name: Install mmcls dependencies
run: |
pip install -r requirements.txt
- name: Build and install
run: |
rm -rf .eggs
pip install -e . -U
- name: Run unittests
run: |
pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
build:
runs-on: ubuntu-latest
env:
UBUNTU_VERSION: ubuntu1804
strategy:
matrix:
python-version: [3.7]
torch: [1.5.0, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.5.0
torchvision: 0.6.0
- torch: 1.6.0
Expand All @@ -40,9 +80,6 @@ jobs:
torchvision: 0.8.1
- torch: 1.8.0
torchvision: 0.9.0
- torch: 1.8.0
torchvision: 0.9.0
python-version: 3.6
- torch: 1.8.0
torchvision: 0.9.0
python-version: 3.8
Expand All @@ -51,9 +88,6 @@ jobs:
python-version: 3.9
- torch: 1.9.0
torchvision: 0.10.0
- torch: 1.9.0
torchvision: 0.10.0
python-version: 3.6
- torch: 1.9.0
torchvision: 0.10.0
python-version: 3.8
Expand All @@ -79,6 +113,9 @@ jobs:
- name: Install mmcls dependencies
run: |
pip install -r requirements.txt
- name: Install timm
run: |
pip install timm
- name: Build and install
run: |
rm -rf .eggs
Expand Down
3 changes: 2 additions & 1 deletion mmcls/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .vgg import VGG
from .vision_transformer import VisionTransformer
Expand All @@ -21,5 +22,5 @@
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT'
'SwinTransformer', 'TNT', 'TIMMBackbone'
]
57 changes: 57 additions & 0 deletions mmcls/models/backbones/timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


@BACKBONES.register_module()
class TIMMBackbone(BaseBackbone):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""

def __init__(
self,
model_name,
pretrained=False,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg)
self.timm_model = timm.create_model(
model_name=model_name,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)

# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None

# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True

def forward(self, x):
features = self.timm_model.forward_features(x)
return features
41 changes: 41 additions & 0 deletions tests/test_models/test_backbones/test_timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.backbones import TIMMBackbone


def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True


def test_timm_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = TIMMBackbone()
model.init_weights(pretrained=0)

# Test resnet18 from timm
model = TIMMBackbone(model_name='resnet18')
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 512, 7, 7))

# Test efficientnet_b1 with pretrained weights
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 1280, 7, 7))

0 comments on commit 6ec098d

Please sign in to comment.