diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 09be56dbb7e..f8714170b10 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,17 +21,57 @@ jobs: - name: Linting run: pre-commit run --all-files - build: + build_without_timm: runs-on: ubuntu-latest env: UBUNTU_VERSION: ubuntu1804 strategy: matrix: - python-version: [3.7] - torch: [1.3.0, 1.5.0, 1.6.0, 1.7.0, 1.8.0, 1.9.0] + python-version: [3.6] + torch: [1.3.0, 1.8.0, 1.9.0] include: - torch: 1.3.0 torchvision: 0.4.2 + - torch: 1.8.0 + torchvision: 0.9.0 + - torch: 1.9.0 + torchvision: 0.10.0 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install Pillow + run: pip install Pillow==6.2.2 + if: ${{matrix.torchvision < 0.5}} + - name: Install PyTorch + run: pip install --use-deprecated=legacy-resolver torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install MMCV + run: | + pip install --use-deprecated=legacy-resolver mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html + python -c 'import mmcv; print(mmcv.__version__)' + - name: Install mmcls dependencies + run: | + pip install -r requirements.txt + - name: Build and install + run: | + rm -rf .eggs + pip install -e . -U + - name: Run unittests + run: | + pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py + + build: + runs-on: ubuntu-latest + env: + UBUNTU_VERSION: ubuntu1804 + strategy: + matrix: + python-version: [3.7] + torch: [1.5.0, 1.6.0, 1.7.0, 1.8.0, 1.9.0] + include: - torch: 1.5.0 torchvision: 0.6.0 - torch: 1.6.0 @@ -40,9 +80,6 @@ jobs: torchvision: 0.8.1 - torch: 1.8.0 torchvision: 0.9.0 - - torch: 1.8.0 - torchvision: 0.9.0 - python-version: 3.6 - torch: 1.8.0 torchvision: 0.9.0 python-version: 3.8 @@ -51,9 +88,6 @@ jobs: python-version: 3.9 - torch: 1.9.0 torchvision: 0.10.0 - - torch: 1.9.0 - torchvision: 0.10.0 - python-version: 3.6 - torch: 1.9.0 torchvision: 0.10.0 python-version: 3.8 @@ -79,6 +113,9 @@ jobs: - name: Install mmcls dependencies run: | pip install -r requirements.txt + - name: Install timm + run: | + pip install timm - name: Build and install run: | rm -rf .eggs diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 47beb4c9b3e..ee255a0b0c9 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -13,6 +13,7 @@ from .shufflenet_v1 import ShuffleNetV1 from .shufflenet_v2 import ShuffleNetV2 from .swin_transformer import SwinTransformer +from .timm_backbone import TIMMBackbone from .tnt import TNT from .vgg import VGG from .vision_transformer import VisionTransformer @@ -21,5 +22,5 @@ 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', - 'SwinTransformer', 'TNT' + 'SwinTransformer', 'TNT', 'TIMMBackbone' ] diff --git a/mmcls/models/backbones/timm_backbone.py b/mmcls/models/backbones/timm_backbone.py new file mode 100644 index 00000000000..6a7109be11d --- /dev/null +++ b/mmcls/models/backbones/timm_backbone.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +@BACKBONES.register_module() +class TIMMBackbone(BaseBackbone): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. + """ + + def __init__( + self, + model_name, + pretrained=False, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super(TIMMBackbone, self).__init__(init_cfg) + self.timm_model = timm.create_model( + model_name=model_name, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model.forward_features(x) + return features diff --git a/tests/test_models/test_backbones/test_timm_backbone.py b/tests/test_models/test_backbones/test_timm_backbone.py new file mode 100644 index 00000000000..660741c7f19 --- /dev/null +++ b/tests/test_models/test_backbones/test_timm_backbone.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch +from torch.nn.modules.batchnorm import _BatchNorm + +from mmcls.models.backbones import TIMMBackbone + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + + +def test_timm_backbone(): + with pytest.raises(TypeError): + # pretrained must be a string path + model = TIMMBackbone() + model.init_weights(pretrained=0) + + # Test resnet18 from timm + model = TIMMBackbone(model_name='resnet18') + model.init_weights() + model.train() + assert check_norm_state(model.modules(), True) + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert feat.shape == torch.Size((1, 512, 7, 7)) + + # Test efficientnet_b1 with pretrained weights + model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert feat.shape == torch.Size((1, 1280, 7, 7))