[Feature] Add with cp to mit and vit (open-mmlab#1431)
* add with cp to mit and vit

* add test unit

Co-authored-by: jiangyitong <jiangyitong1@sensetime.com>
jiangyitong and jiangyitong1 authored Apr 1, 2022
1 parent 368d821 commit 7b6953f
Showing 4 changed files with 62 additions and 10 deletions.
29 changes: 24 additions & 5 deletions mmseg/models/backbones/mit.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
@@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
Default: None.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory at the cost of slower training. Default: False.
"""

def __init__(self,
@@ -248,7 +251,8 @@ def __init__(self,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
sr_ratio=1):
sr_ratio=1,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()

# The ret[0] of build_norm_layer is norm name.
@@ -275,9 +279,19 @@ def __init__(self,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)

self.with_cp = with_cp

def forward(self, x, hw_shape):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)

def _inner_forward(x):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x

if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x


@@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory at the cost of slower training. Default: False.
"""

def __init__(self,
@@ -339,7 +355,8 @@ def __init__(self,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrained=None,
init_cfg=None):
init_cfg=None,
with_cp=False):
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

assert not (init_cfg and pretrained), \
@@ -358,8 +375,9 @@ def __init__(self,
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
self.with_cp = with_cp
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)

self.out_indices = out_indices
assert max(out_indices) < self.num_stages
@@ -392,6 +410,7 @@ def __init__(self,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
])
in_channels = embed_dims_i
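
For context, a minimal sketch of how the new flag is meant to be used on the MiT backbone; the small embed_dims and the input size below are illustrative choices, not values from this diff:

import torch

from mmseg.models.backbones import MixVisionTransformer

# Build a MiT backbone with activation checkpointing enabled. With
# with_cp=True, each TransformerEncoderLayer discards its intermediate
# activations on the forward pass and recomputes them during backward,
# trading extra compute for lower peak memory.
model = MixVisionTransformer(embed_dims=32, with_cp=True)
model.train()

x = torch.randn(1, 3, 224, 224)
outs = model(x)  # multi-scale feature maps, one per stage

# The checkpointed path only runs when the tokens require grad, so this
# backward pass recomputes each block's activations on the fly.
loss = sum(o.sum() for o in outs)
loss.backward()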
21 changes: 18 additions & 3 deletions mmseg/models/backbones/vit.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory at the cost of slower training. Default: False.
"""

def __init__(self,
@@ -54,7 +57,8 @@ def __init__(self,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True):
batch_first=True,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()

self.norm1_name, norm1 = build_norm_layer(
@@ -82,6 +86,8 @@ def __init__(self,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)

self.with_cp = with_cp

@property
def norm1(self):
return getattr(self, self.norm1_name)
@@ -91,8 +97,16 @@ def norm2(self):
return getattr(self, self.norm2_name)

def forward(self, x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)

def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x

if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x


@@ -251,6 +265,7 @@ def __init__(self,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))

self.final_norm = final_norm
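
Both backbones use the same guard around torch.utils.checkpoint: wrap the block body in a local _inner_forward and only checkpoint it when the input requires grad, so inference pays no recomputation overhead. A self-contained sketch of the pattern (the toy module and its names are illustrative, not part of this diff):

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedBlock(nn.Module):
    """Toy residual block using the same with_cp pattern as the diff."""

    def __init__(self, dim, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):

        def _inner_forward(x):
            return x + self.fc(x)

        # cp.checkpoint re-runs _inner_forward during backward instead of
        # storing its activations; the requires_grad check skips that
        # machinery when no backward pass can follow (e.g. at inference).
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)


block = CheckpointedBlock(64, with_cp=True)
x = torch.randn(2, 64, requires_grad=True)
block(x).sum().backward()  # activations are recomputed inside backward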
11 changes: 10 additions & 1 deletion tests/test_models/test_backbones/test_mit.py
@@ -3,7 +3,8 @@
import torch

from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
TransformerEncoderLayer)


def test_mit():
@@ -56,6 +57,14 @@ def test_mit():
outs = MHA(temp, hw_shape, temp)
assert outs.shape == (1, token_len, 64)

# Test TransformerEncoderLayer with checkpoint forward
block = TransformerEncoderLayer(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])


def test_mit_init():
path = 'PATH_THAT_DO_NOT_EXIST'
11 changes: 10 additions & 1 deletion tests/test_models/test_backbones/test_vit.py
@@ -2,7 +2,8 @@
import pytest
import torch

from mmseg.models.backbones.vit import VisionTransformer
from mmseg.models.backbones.vit import (TransformerEncoderLayer,
VisionTransformer)
from .utils import check_norm_state


@@ -119,6 +120,14 @@ def test_vit_backbone():
assert feat[0][0].shape == (1, 768, 14, 14)
assert feat[0][1].shape == (1, 768)

# Test TransformerEncoderLayer with checkpoint forward
block = TransformerEncoderLayer(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x)
assert x_out.shape == torch.Size([1, 56 * 56, 64])


def test_vit_init():
path = 'PATH_THAT_DO_NOT_EXIST'
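
In practice the flag would usually be flipped from a training config rather than by constructing the backbone directly. A sketch of a SegFormer-style backbone dict, assuming the class is registered under its own name; the field values here are illustrative and only the backbone portion is shown:

# Enable activation checkpointing by passing with_cp=True through the
# backbone config; the builder forwards these kwargs to the class.
backbone = dict(
    type='MixVisionTransformer',
    embed_dims=64,
    num_stages=4,
    num_layers=[3, 4, 6, 3],
    num_heads=[1, 2, 4, 8],
    with_cp=True)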
