Skip to content

Commit 3b3cbf2

Browse files
Yshuo-Liliyinshuo
and
liyinshuo
authored
[Feature] Add config of ttsr-gan (open-mmlab#398)
* Update * [Feature] Add config of ttsr-gan Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
1 parent f52c456 commit 3b3cbf2

File tree

3 files changed

+298
-4
lines changed

3 files changed

+298
-4
lines changed

configs/restorers/ttsr/README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Evaluated on RGB channels, `scale` pixels in each border are cropped before eval
2020

2121
The metrics are `PSNR / SSIM`.
2222

23-
| Method | scale | CUFED | Download |
24-
| :---------------------------------------------------------------------------------------------: | :---: | :--------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
25-
| [ttsr-rec_x4_c64b16_g1_200k_CUFED](/configs/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED.py) | x4 | 25.2433 / 0.7491 | [model](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.pth?versionId=CAEQKxiBgIDht5ONzRciIDdjZTQ1NmFmYzhjNjQ5NGFhNjkyNzU1N2UxMjMyZWE4) \| [log](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.log.json?versionId=CAEQKxiCgMCnuJONzRciIDUzNmVkNGNmNTlkMDQzMmFhZDAzYzQ5NmUzNTI5YmYz) |
23+
| Method | scale | CUFED | Download |
24+
| :---------------------------------------------------------------------------------------------: | :---: | :--------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
25+
| [ttsr-rec_x4_c64b16_g1_200k_CUFED](/configs/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED.py) | x4 | 25.2433 / 0.7491 | [model](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.log.json) |
26+
| [ttsr-gan_x4_c64b16_g1_500k_CUFED](/configs/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED.py) | x4 | 24.6075 / 0.7234 | [model](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED_20210626-2ab28ca0.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED_20210626-2ab28ca0.log.json) |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
exp_name = 'ttsr-gan_x4_c64b16_g1_500k_CUFED'
2+
scale = 4
3+
4+
# model settings
5+
model = dict(
6+
type='TTSR',
7+
generator=dict(
8+
type='TTSRNet',
9+
in_channels=3,
10+
out_channels=3,
11+
mid_channels=64,
12+
num_blocks=(16, 16, 8, 4)),
13+
extractor=dict(type='LTE'),
14+
transformer=dict(type='SearchTransformer'),
15+
discriminator=dict(type='TTSRDiscriminator', in_size=160),
16+
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
17+
perceptual_loss=dict(
18+
type='PerceptualLoss',
19+
layer_weights={'29': 1.0},
20+
vgg_type='vgg19',
21+
perceptual_weight=1e-2,
22+
style_weight=0,
23+
criterion='mse'),
24+
transferal_perceptual_loss=dict(
25+
type='TransferalPerceptualLoss',
26+
loss_weight=1e-2,
27+
use_attention=False,
28+
criterion='mse'),
29+
gan_loss=dict(
30+
type='GANLoss',
31+
gan_type='vanilla',
32+
loss_weight=1e-3,
33+
real_label_val=1.0,
34+
fake_label_val=0))
35+
# model training and testing settings
36+
train_cfg = dict(fix_iter=25000, disc_steps=2)
37+
test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale)
38+
39+
# dataset settings
40+
train_dataset_type = 'SRFolderRefDataset'
41+
val_dataset_type = 'SRFolderRefDataset'
42+
test_dataset_type = 'SRFolderRefDataset'
43+
train_pipeline = [
44+
dict(
45+
type='LoadImageFromFile',
46+
io_backend='disk',
47+
key='gt',
48+
flag='color',
49+
channel_order='rgb',
50+
backend='pillow'),
51+
dict(
52+
type='LoadImageFromFile',
53+
io_backend='disk',
54+
key='ref',
55+
flag='color',
56+
channel_order='rgb',
57+
backend='pillow'),
58+
dict(type='CropLike', target_key='ref', reference_key='gt'),
59+
dict(
60+
type='Resize',
61+
scale=1 / scale,
62+
keep_ratio=True,
63+
keys=['gt', 'ref'],
64+
output_keys=['lq', 'ref_down'],
65+
interpolation='bicubic',
66+
backend='pillow'),
67+
dict(
68+
type='Resize',
69+
scale=float(scale),
70+
keep_ratio=True,
71+
keys=['lq', 'ref_down'],
72+
output_keys=['lq_up', 'ref_downup'],
73+
interpolation='bicubic',
74+
backend='pillow'),
75+
dict(
76+
type='Normalize',
77+
keys=['lq', 'gt'],
78+
mean=[127.5, 127.5, 127.5],
79+
std=[127.5, 127.5, 127.5]),
80+
dict(
81+
type='Normalize',
82+
keys=['lq_up', 'ref', 'ref_downup'],
83+
mean=[0., 0., 0.],
84+
std=[255., 255., 255.]),
85+
dict(
86+
type='Flip',
87+
keys=['lq', 'gt', 'lq_up'],
88+
flip_ratio=0.5,
89+
direction='horizontal'),
90+
dict(
91+
type='Flip',
92+
keys=['lq', 'gt', 'lq_up'],
93+
flip_ratio=0.5,
94+
direction='vertical'),
95+
dict(
96+
type='RandomTransposeHW',
97+
keys=['lq', 'gt', 'lq_up'],
98+
transpose_ratio=0.5),
99+
dict(
100+
type='Flip',
101+
keys=['ref', 'ref_downup'],
102+
flip_ratio=0.5,
103+
direction='horizontal'),
104+
dict(
105+
type='Flip',
106+
keys=['ref', 'ref_downup'],
107+
flip_ratio=0.5,
108+
direction='vertical'),
109+
dict(
110+
type='RandomTransposeHW',
111+
keys=['ref', 'ref_downup'],
112+
transpose_ratio=0.5),
113+
dict(
114+
type='ImageToTensor', keys=['lq', 'gt', 'lq_up', 'ref', 'ref_downup']),
115+
dict(
116+
type='Collect',
117+
keys=['lq', 'gt', 'lq_up', 'ref', 'ref_downup'],
118+
meta_keys=['gt_path', 'ref_path'])
119+
]
120+
valid_pipeline = [
121+
dict(
122+
type='LoadImageFromFile',
123+
io_backend='disk',
124+
key='gt',
125+
flag='color',
126+
channel_order='rgb',
127+
backend='pillow'),
128+
dict(
129+
type='LoadImageFromFile',
130+
io_backend='disk',
131+
key='ref',
132+
flag='color',
133+
channel_order='rgb',
134+
backend='pillow'),
135+
dict(type='CropLike', target_key='ref', reference_key='gt'),
136+
dict(
137+
type='Resize',
138+
scale=1 / scale,
139+
keep_ratio=True,
140+
keys=['gt', 'ref'],
141+
output_keys=['lq', 'ref_down'],
142+
interpolation='bicubic',
143+
backend='pillow'),
144+
dict(
145+
type='Resize',
146+
scale=float(scale),
147+
keep_ratio=True,
148+
keys=['lq', 'ref_down'],
149+
output_keys=['lq_up', 'ref_downup'],
150+
interpolation='bicubic',
151+
backend='pillow'),
152+
dict(
153+
type='Normalize',
154+
keys=['lq', 'gt'],
155+
mean=[127.5, 127.5, 127.5],
156+
std=[127.5, 127.5, 127.5]),
157+
dict(
158+
type='Normalize',
159+
keys=['lq_up', 'ref', 'ref_downup'],
160+
mean=[0., 0., 0.],
161+
std=[255., 255., 255.]),
162+
dict(
163+
type='ImageToTensor', keys=['lq', 'gt', 'lq_up', 'ref', 'ref_downup']),
164+
dict(
165+
type='Collect',
166+
keys=['lq', 'gt', 'lq_up', 'ref', 'ref_downup'],
167+
meta_keys=['gt_path', 'ref_path'])
168+
]
169+
test_pipeline = [
170+
dict(
171+
type='LoadImageFromFile',
172+
io_backend='disk',
173+
key='lq',
174+
flag='color',
175+
channel_order='rgb',
176+
backend='pillow'),
177+
dict(
178+
type='LoadImageFromFile',
179+
io_backend='disk',
180+
key='ref',
181+
flag='color',
182+
channel_order='rgb',
183+
backend='pillow'),
184+
dict(
185+
type='Resize',
186+
scale=1 / scale,
187+
keep_ratio=True,
188+
keys=['ref'],
189+
output_keys=['ref_down'],
190+
interpolation='bicubic',
191+
backend='pillow'),
192+
dict(
193+
type='Resize',
194+
scale=float(scale),
195+
keep_ratio=True,
196+
keys=['lq', 'ref_down'],
197+
output_keys=['lq_up', 'ref_downup'],
198+
interpolation='bicubic',
199+
backend='pillow'),
200+
dict(
201+
type='Normalize',
202+
keys=['lq'],
203+
mean=[127.5, 127.5, 127.5],
204+
std=[127.5, 127.5, 127.5]),
205+
dict(
206+
type='Normalize',
207+
keys=['lq_up', 'ref', 'ref_downup'],
208+
mean=[0., 0., 0.],
209+
std=[255., 255., 255.]),
210+
dict(type='ImageToTensor', keys=['lq', 'lq_up', 'ref', 'ref_downup']),
211+
dict(
212+
type='Collect',
213+
keys=['lq', 'lq_up', 'ref', 'ref_downup'],
214+
meta_keys=['lq_path', 'ref_path'])
215+
]
216+
217+
data = dict(
218+
workers_per_gpu=9,
219+
train_dataloader=dict(samples_per_gpu=9, drop_last=True),
220+
val_dataloader=dict(samples_per_gpu=1),
221+
test_dataloader=dict(samples_per_gpu=1),
222+
train=dict(
223+
type='RepeatDataset',
224+
times=52,
225+
dataset=dict(
226+
type=train_dataset_type,
227+
gt_folder='data/CUFED/train/input/',
228+
ref_folder='data/CUFED/train/ref/',
229+
pipeline=train_pipeline,
230+
scale=scale)),
231+
val=dict(
232+
type=val_dataset_type,
233+
gt_folder='data/CUFED/valid/input_format/',
234+
ref_folder='data/CUFED/valid/ref1_format/',
235+
pipeline=valid_pipeline,
236+
scale=scale),
237+
test=dict(
238+
type=test_dataset_type,
239+
gt_folder='data/CUFED/valid/input_format/',
240+
ref_folder='data/CUFED/valid/ref1_format/',
241+
pipeline=valid_pipeline,
242+
scale=scale))
243+
244+
# optimizer
245+
optimizers = dict(
246+
generator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)),
247+
discriminator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
248+
249+
# learning policy
250+
total_iters = 500000
251+
lr_config = dict(
252+
policy='Step',
253+
by_epoch=False,
254+
step=[100000, 200000, 300000, 400000],
255+
gamma=0.5)
256+
257+
checkpoint_config = dict(interval=100, save_optimizer=True, by_epoch=False)
258+
evaluation = dict(interval=5000, save_image=True, gpu_collect=True)
259+
log_config = dict(
260+
interval=100,
261+
hooks=[
262+
dict(type='TextLoggerHook', by_epoch=False),
263+
# dict(type='TensorboardLoggerHook')
264+
])
265+
visual_config = None
266+
267+
# runtime settings
268+
dist_params = dict(backend='nccl')
269+
log_level = 'INFO'
270+
work_dir = f'./work_dirs/{exp_name}'
271+
load_from = None
272+
resume_from = None
273+
workflow = [('train', 1)]
274+
find_unused_parameters = True

tests/test_models/test_restorers/test_ttsr.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,26 @@ def test_ttsr():
7272
num_blocks=(16, 16, 8, 4)),
7373
extractor=dict(type='LTE'),
7474
transformer=dict(type='SearchTransformer'),
75-
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
75+
discriminator=dict(type='TTSRDiscriminator', in_size=64),
76+
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
77+
perceptual_loss=dict(
78+
type='PerceptualLoss',
79+
layer_weights={'29': 1.0},
80+
vgg_type='vgg19',
81+
perceptual_weight=1e-2,
82+
style_weight=0.001,
83+
criterion='mse'),
84+
transferal_perceptual_loss=dict(
85+
type='TransferalPerceptualLoss',
86+
loss_weight=1e-2,
87+
use_attention=False,
88+
criterion='mse'),
89+
gan_loss=dict(
90+
type='GANLoss',
91+
gan_type='vanilla',
92+
loss_weight=1e-3,
93+
real_label_val=1.0,
94+
fake_label_val=0))
7695

7796
scale = 4
7897
train_cfg = None

0 commit comments

Comments
 (0)