[Fix] Fix wrong init usage in transformer models (open-mmlab#1069)
* fix wrong trunc_normal_init usage

* fix mit init weights

* fix vit init weights

* fix mit init weights

* fix typo

* fix swin init weights
Junjun2016 authored Dec 6, 2021
1 parent 2918220 commit 3057ef6
Showing 3 changed files with 25 additions and 32 deletions.
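
The underlying issue is the same in all three backbones: mmcv's module-level init helpers (trunc_normal_init, constant_init, normal_init, kaiming_init) expect an nn.Module and set its weight and bias together, while the underscore-suffixed functions (trunc_normal_, nn.init.*) act on a single tensor such as an nn.Parameter. The old code passed raw tensors (m.weight, m.bias) to the module-level helpers. A minimal sketch of the two call styles, assuming torch and mmcv are installed (layer shapes are illustrative, not taken from the commit):

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)

linear = nn.Linear(8, 8)                         # illustrative module
norm = nn.LayerNorm(8)                           # illustrative module
pos_embed = nn.Parameter(torch.zeros(1, 16, 8))  # illustrative parameter

# Module-level helpers: pass the module; weight and bias are set together.
trunc_normal_init(linear, std=.02, bias=0.)
constant_init(norm, val=1.0, bias=0.)

# Tensor-level helper: pass the tensor/Parameter itself.
trunc_normal_(pos_embed, std=.02)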
20 changes: 8 additions & 12 deletions mmseg/models/backbones/mit.py
@@ -4,10 +4,11 @@

 import torch
 import torch.nn as nn
-from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint

 from ...utils import get_root_logger
@@ -343,7 +344,7 @@ def __init__(self,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)

         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ def __init__(self,
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
         self.pretrained = pretrained
-        self.init_cfg = init_cfg

         # transformer encoder
         dpr = [
@@ -407,19 +407,15 @@ def init_weights(self):
         if self.pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
                 elif isinstance(m, nn.Conv2d):
                     fan_out = m.kernel_size[0] * m.kernel_size[
                         1] * m.out_channels
                     fan_out //= m.groups
-                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    normal_init(
+                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
         elif isinstance(self.pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
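In the mit.py Conv2d branch, the std is the Kaiming/He fan-out value sqrt(2 / fan_out), with fan_out = k_h * k_w * out_channels / groups. A worked sketch on a hypothetical depthwise 3x3 conv (the layer is illustrative, not taken from the diff):

import math
import torch.nn as nn
from mmcv.cnn.utils.weight_init import normal_init

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64)  # hypothetical

fan_out = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
fan_out //= conv.groups                                  # 3 * 3 * 64 / 64 = 9
normal_init(conv, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)  # std ~ 0.47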
15 changes: 7 additions & 8 deletions mmseg/models/backbones/swin.py
@@ -7,8 +7,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from mmcv.utils import to_2tuple

@@ -73,7 +75,7 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=-1)

     def init_weights(self):
-        trunc_normal_init(self.relative_position_bias_table, std=0.02)
+        trunc_normal_(self.relative_position_bias_table, std=0.02)

     def forward(self, x, mask=None):
         """
@@ -665,15 +667,12 @@ def init_weights(self):
                         f'{self.__class__.__name__}, '
                         f'training start from scratch')
             if self.use_abs_pos_embed:
-                trunc_normal_init(self.absolute_pos_embed, std=0.02)
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
         else:
             assert 'checkpoint' in self.init_cfg, f'Only support ' \
                 f'specify `Pretrained` in ' \
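Swin's init_weights branches on init_cfg: with init_cfg=None it runs the from-scratch initialization shown above; otherwise it expects a Pretrained init_cfg with a checkpoint key and loads weights instead. A hedged usage sketch that exercises the from-scratch branch (embed_dims=96 is simply the default):

from mmseg.models.backbones import SwinTransformer

# init_cfg=None -> the from-scratch branch above applies trunc_normal_,
# trunc_normal_init and constant_init exactly as in the diff.
model = SwinTransformer(embed_dims=96, init_cfg=None)
model.init_weights()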
22 changes: 10 additions & 12 deletions mmseg/models/backbones/vit.py
@@ -4,9 +4,10 @@

 import torch
 import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -292,23 +293,20 @@ def init_weights(self):
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
-            trunc_normal_init(self.pos_embed, std=.02)
-            trunc_normal_init(self.cls_token, std=.02)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'ffn' in n:
-                            normal_init(m.bias, std=1e-6)
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                         else:
-                            constant_init(m.bias, 0)
+                            nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.Conv2d):
-                    kaiming_init(m.weight, mode='fan_in')
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    kaiming_init(m, mode='fan_in', bias=0.)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)

     def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
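In vit.py the Linear branch keeps per-tensor calls rather than switching to trunc_normal_init(m, ...), because the bias init depends on the submodule name: FFN biases get normal noise with std=1e-6 while all other biases are zeroed. A small sketch of that name-driven dispatch on a hypothetical toy model (module names are illustrative):

import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_

# Hypothetical toy model; only the submodule names matter here.
model = nn.ModuleDict({
    'attn': nn.ModuleDict({'proj': nn.Linear(8, 8)}),
    'ffn': nn.ModuleDict({'fc1': nn.Linear(8, 8)}),
})

for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            if 'ffn' in n:
                nn.init.normal_(m.bias, mean=0., std=1e-6)  # FFN bias
            else:
                nn.init.constant_(m.bias, 0)                # all other biases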
