Skip to content

Commit

Permalink
[Improve] Fixed typo in RepVGG. (open-mmlab#985)
Browse files Browse the repository at this point in the history
* [Improve] Use `forward_dummy` to calculate FLOPS. (open-mmlab#953)

* fixed

Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
  • Loading branch information
techmonsterwang and twmht authored Aug 22, 2022
1 parent 5ad3bed commit ec71d07
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/mobilenet_v3/mobilenet-v3-small_8xb16_cifar10.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_base_ = [
'../_base_/models/mobilenet-v3-small_8xb16_cifar.py',
'../_base_/models/mobilenet-v3-small_cifar.py',
'../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
Expand Down
8 changes: 4 additions & 4 deletions mmcls/models/backbones/repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _fuse_conv_bn(self, branch):

return fused_weight, fused_bias

def _norm_to_conv3x3(self, branch_nrom):
def _norm_to_conv3x3(self, branch_norm):
"""Convert a norm layer to a conv3x3-bn sequence.
Args:
Expand All @@ -242,15 +242,15 @@ def _norm_to_conv3x3(self, branch_nrom):
"""
input_dim = self.in_channels // self.groups
conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
dtype=branch_nrom.weight.dtype)
dtype=branch_norm.weight.dtype)

for i in range(self.in_channels):
conv_weight[i, i % input_dim, 1, 1] = 1
conv_weight = conv_weight.to(branch_nrom.weight.device)
conv_weight = conv_weight.to(branch_norm.weight.device)

tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
tmp_conv3x3.conv.weight.data = conv_weight
tmp_conv3x3.norm = branch_nrom
tmp_conv3x3.norm = branch_norm
return tmp_conv3x3


Expand Down

0 comments on commit ec71d07

Please sign in to comment.