[Feature] Add with cp to mit and vit #1431

Merged 3 commits on Apr 1, 2022
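
This PR adds a `with_cp` option to the `TransformerEncoderLayer` used by the MiT (Segformer) and ViT backbones. When the flag is enabled, each encoder layer wraps its attention + FFN body in `torch.utils.checkpoint`, trading extra compute in the backward pass for lower activation memory.
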
29 changes: 24 additions & 5 deletions mmseg/models/backbones/mit.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
@@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
Default:None.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""

def __init__(self,
@@ -248,7 +251,8 @@ def __init__(self,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
sr_ratio=1):
sr_ratio=1,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()

# The ret[0] of build_norm_layer is norm name.
@@ -275,9 +279,19 @@ def __init__(self,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)

self.with_cp = with_cp

def forward(self, x, hw_shape):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)

def _inner_forward(x):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x

if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x


@@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""

def __init__(self,
@@ -339,7 +355,8 @@ def __init__(self,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrained=None,
init_cfg=None):
init_cfg=None,
with_cp=False):
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

assert not (init_cfg and pretrained), \
@@ -358,8 +375,9 @@ def __init__(self,
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
self.with_cp = with_cp
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)
== len(patch_sizes) == len(strides) == len(sr_ratios)

self.out_indices = out_indices
assert max(out_indices) < self.num_stages
@@ -392,6 +410,7 @@ def __init__(self,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
])
in_channels = embed_dims_i
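
The `with_cp and x.requires_grad` guard above means checkpointing only engages when gradients will actually flow, so inference cost is unchanged. A minimal usage sketch, not part of this PR; the import path and the use of default constructor arguments are assumptions about mmseg's usual layout:

```python
import torch

from mmseg.models.backbones import MixVisionTransformer  # assumed import path

# Build the MiT backbone with activation checkpointing enabled.
backbone = MixVisionTransformer(with_cp=True)

# Inference: inputs never require grad, so the checkpoint branch is skipped.
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 128, 128))

# Training: features require grad after the patch embedding, so each encoder
# layer is recomputed during backward instead of caching its activations.
backbone.train()
feats = backbone(torch.randn(1, 3, 128, 128))
sum(f.sum() for f in feats).backward()
```
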
21 changes: 18 additions & 3 deletions mmseg/models/backbones/vit.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""

def __init__(self,
@@ -54,7 +57,8 @@ def __init__(self,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True):
batch_first=True,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()

self.norm1_name, norm1 = build_norm_layer(
@@ -82,6 +86,8 @@ def __init__(self,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)

self.with_cp = with_cp

@property
def norm1(self):
return getattr(self, self.norm1_name)
@@ -91,8 +97,16 @@ def norm2(self):
return getattr(self, self.norm2_name)

def forward(self, x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)

def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x

if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x


@@ -251,6 +265,7 @@ def __init__(self,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))

self.final_norm = final_norm
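
Both files now share the same pattern: move the layer body into a local `_inner_forward` and route it through `torch.utils.checkpoint` when `with_cp` is set. A self-contained sketch of that pattern in plain PyTorch; module names and dimensions here are illustrative, not mmseg's:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedBlock(nn.Module):
    """Pre-norm attention + FFN block with optional activation checkpointing."""

    def __init__(self, embed_dims=64, num_heads=4, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.norm1 = nn.LayerNorm(embed_dims)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dims, 4 * embed_dims),
            nn.GELU(),
            nn.Linear(4 * embed_dims, embed_dims))

    def forward(self, x):

        def _inner_forward(x):
            y = self.norm1(x)
            x = x + self.attn(y, y, y, need_weights=False)[0]
            x = x + self.ffn(self.norm2(x))
            return x

        # Same guard as the PR: only checkpoint when gradients are needed.
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


block = CheckpointedBlock(with_cp=True)
tokens = torch.randn(2, 196, 64, requires_grad=True)
out = block(tokens)
out.sum().backward()  # the block body is re-run here to rebuild saved activations
```

The trade-off is the one the new docstrings describe: intermediate activations are freed after the forward pass and recomputed during backward, so memory drops while each training step takes somewhat longer.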