Skip to content

Commit 8345720

Browse files
Yshuo-Liliyinshuo
and
liyinshuo
authored
Do not load pretrained VGG (open-mmlab#466)
* Don not load pretrained VGG * Add load_pretrained_vgg Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
1 parent a5e1561 commit 8345720

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

mmedit/models/extractors/lte.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,18 @@ class LTE(nn.Module):
1818
requires_grad (bool): Require grad or not. Default: True.
1919
pixel_range (float): Pixel range of geature. Default: 1.
2020
pretrained (str): Path for pretrained model. Default: None.
21+
load_pretrained_vgg (bool): Load pretrained VGG from torchvision.
22+
Default: True.
23+
Train: must load pretrained VGG
24+
Eval: needn't load pretrained VGG, because we will load pretrained
25+
LTE.
2126
"""
2227

23-
def __init__(self, requires_grad=True, pixel_range=1., pretrained=None):
28+
def __init__(self,
29+
requires_grad=True,
30+
pixel_range=1.,
31+
pretrained=None,
32+
load_pretrained_vgg=True):
2433
super().__init__()
2534

2635
vgg_mean = (0.485, 0.456, 0.406)
@@ -30,7 +39,8 @@ def __init__(self, requires_grad=True, pixel_range=1., pretrained=None):
3039
pixel_range=pixel_range, img_mean=vgg_mean, img_std=vgg_std)
3140

3241
# use vgg19 weights to initialize
33-
vgg_pretrained_features = models.vgg19(pretrained=True).features
42+
vgg_pretrained_features = models.vgg19(
43+
pretrained=load_pretrained_vgg).features
3444

3545
self.slice1 = torch.nn.Sequential()
3646
self.slice2 = torch.nn.Sequential()

tests/test_models/test_extractors/test_lte.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
def test_lte():
88
model_cfg = dict(
9-
type='LTE', requires_grad=False, pixel_range=1., pretrained=None)
9+
type='LTE',
10+
requires_grad=False,
11+
pixel_range=1.,
12+
pretrained=None,
13+
load_pretrained_vgg=False)
1014

1115
lte = build_component(model_cfg)
1216
assert lte.__class__.__name__ == 'LTE'

tests/test_models/test_restorers/test_ttsr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_ttsr():
7070
out_channels=3,
7171
mid_channels=64,
7272
num_blocks=(16, 16, 8, 4)),
73-
extractor=dict(type='LTE'),
73+
extractor=dict(type='LTE', load_pretrained_vgg=False),
7474
transformer=dict(type='SearchTransformer'),
7575
discriminator=dict(type='TTSRDiscriminator', in_size=64),
7676
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),

0 commit comments

Comments
 (0)