Skip to content

Commit

Permalink
fix: register running_mean and running_var as buffers, not parameters
Browse files Browse the repository at this point in the history
- Remove self-defined BatchNorm2d
- Add named_buffers() in load_weight module
  • Loading branch information
hankyul2 committed Dec 3, 2021
1 parent c84a07d commit d3fdee2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
12 changes: 2 additions & 10 deletions efficientnetv2/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,6 @@
from efficientnetv2.pretrained_weight_loader import load_from_zoo


class BatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d whose running statistics remain ordinary buffers.

    The previous implementation re-wrapped ``running_mean`` and
    ``running_var`` in ``nn.Parameter(..., requires_grad=False)`` so they
    would surface through ``named_parameters()`` when copying weights from
    a TensorFlow checkpoint. That breaks PyTorch's buffer semantics: the
    running statistics end up in optimizer parameter groups and are
    treated as trainable state by generic parameter-walking code.

    Fix: leave the statistics as the plain buffers that
    ``nn.BatchNorm2d`` registers. Weight-loading code should iterate
    ``named_buffers()`` in addition to ``named_parameters()`` to copy
    them (see the companion change in ``pretrained_weight_loader``).
    The subclass is kept (rather than deleted) so existing references to
    this name keep working.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally no extra behavior: defer entirely to nn.BatchNorm2d,
        # which registers running_mean/running_var via register_buffer().
        super().__init__(*args, **kwargs)


class ConvBNAct(nn.Sequential):
"""Convolution-Normalization-Activation Module"""
def __init__(self, in_channel, out_channel, kernel_size, stride, groups, norm_layer, act, conv_layer=nn.Conv2d):
Expand Down Expand Up @@ -72,7 +64,7 @@ def forward(self, x):
class MBConvConfig:
"""EfficientNet Building block configuration"""
def __init__(self, expand_ratio: float, kernel: int, stride: int, in_ch: int, out_ch: int, layers: int,
use_se: bool, fused: bool, act=nn.SiLU, norm_layer=BatchNorm2d):
use_se: bool, fused: bool, act=nn.SiLU, norm_layer=nn.BatchNorm2d):
self.expand_ratio = expand_ratio
self.kernel = kernel
self.stride = stride
Expand Down Expand Up @@ -142,7 +134,7 @@ class EfficientNetV2(nn.Module):
- stochastic depth: stochastic depth probability
"""
def __init__(self, layer_infos, out_channels=1280, nclass=0, dropout=0.2, stochastic_depth=0.0,
block=MBConv, act_layer=nn.SiLU, norm_layer=BatchNorm2d):
block=MBConv, act_layer=nn.SiLU, norm_layer=nn.BatchNorm2d):
super(EfficientNetV2, self).__init__()
self.layer_infos = layer_infos
self.norm_layer = norm_layer
Expand Down
4 changes: 3 additions & 1 deletion efficientnetv2/pretrained_weight_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,13 @@ def load_npy(model, weight):
('\\.(\\d+)\\.', lambda x: f'_{int(x.group(1))}/'),
]

for name, param in model.named_parameters():
for name, param in list(model.named_parameters()) + list(model.named_buffers()):
for pattern, sub in name_convertor:
name = re.sub(pattern, sub, name)
if 'dense/kernel' in name and list(param.shape) not in [[1000, 1280], [21843, 1280]]:
continue
if 'dense/bias' in name and list(param.shape) not in [[1000], [21843]]:
continue
if 'num_batches_tracked' in name:
continue
param.data.copy_(npz_dim_convertor(name, weight.get(name)))

0 comments on commit d3fdee2

Please sign in to comment.