-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathdefault_config.py
104 lines (87 loc) · 2.67 KB
/
default_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""
Default arguments from [1]. Entries can be manually overridden via
command line arguments in `train.py`.
[1]: arXiv 2006.09965
"""
class ModelTypes:
    """String identifiers for the model variants a config can select."""

    # Distortion + perceptual losses only (selected by `mse_lpips_args`).
    COMPRESSION = 'compression'
    # Adds the generative/adversarial loss terms (selected by `hific_args`).
    COMPRESSION_GAN = 'compression_gan'
class ModelModes:
    """String identifiers for the phase in which a model is run."""

    TRAINING = 'training'
    VALIDATION = 'validation'
    EVALUATION = 'evaluation'
class Datasets:
    """String identifiers for the datasets a config can select."""

    OPENIMAGES = 'openimages'
    CITYSCAPES = 'cityscapes'
    # Note: the identifier string is 'jetimages', not 'jets'.
    JETS = 'jetimages'
class DatasetPaths:
    """Filesystem locations paired with the identifiers in `Datasets`.

    Only OpenImages has a path here; the empty strings look like
    placeholders that must be filled in before use (TODO confirm).
    """

    OPENIMAGES = 'data/openimages'
    CITYSCAPES = ''
    JETS = ''
class directories:
    """Output directory names used by the project."""

    experiments = 'experiments'
class checkpoints:
    """Paths to saved model checkpoints."""

    # NOTE(review): attribute is named `gan1` but the file is `lossless.pt`;
    # confirm which model this checkpoint actually holds.
    gan1 = 'experiments/lossless.pt'
class args:
    """
    Shared config

    Base configuration inherited by every specialized config in this module.
    Defaults follow [1]; entries can be overridden from the command line in
    `train.py`.

    NOTE: `target_rate` and `lambda_A` are looked up from `regime` once,
    when this class body executes. A subclass that overrides only `regime`
    still inherits the values for 'low' and must re-derive both attributes.
    """
    name = 'hific_v0.1'
    silent = True
    n_epochs = 8
    # Integer rather than the float literal 1e6, so it can be passed to
    # range() and printed as a step count without a trailing '.0'.
    n_steps = int(1e6)
    batch_size = 8
    log_interval = 1000
    save_interval = 50000
    gpu = 0
    multigpu = True
    dataset = Datasets.OPENIMAGES
    dataset_path = DatasetPaths.OPENIMAGES
    shuffle = True
    discriminator_steps = 0  # hific_args overrides this to 1
    model_mode = ModelModes.TRAINING

    # Architecture params - Table 3a) of [1]
    latent_channels = 220
    n_residual_blocks = 7          # Authors use 9 blocks, performance saturates at 5
    lambda_B = 2**(-4)             # Rate weight ("loose" setting; compare lambda_A_map)
    k_M = 0.075 * 2**(-5)          # Distortion
    k_P = 1.                       # Perceptual loss
    beta = 0.15                    # Generator loss
    use_channel_norm = True
    likelihood_type = 'gaussian'   # Latent likelihood model
    normalize_input_image = False  # Normalize inputs to range [-1,1]

    # Shapes
    crop_size = 256
    # Derived from crop_size so the image shape cannot drift from the crop.
    image_dims = (3, crop_size, crop_size)
    # NOTE(review): 16 == 256 / 16, i.e. presumably a 16x spatial downsampling
    # by the encoder -- confirm before changing crop_size.
    latent_dims = (latent_channels, 16, 16)

    # Optimizer params
    learning_rate = 1e-4
    weight_decay = 1e-6

    # Scheduling: each schedule switches from vals[0] to vals[1] at steps[0]
    # (presumably as a multiplier on the base value -- confirm in the trainer).
    lambda_schedule = dict(vals=[2., 1.], steps=[50000])
    lr_schedule = dict(vals=[1., 0.1], steps=[500000])
    target_schedule = dict(vals=[0.20/0.14, 1.], steps=[50000])  # Rate allowance

    # match target rate to lambda_A coefficient
    regime = 'low'  # -> 0.14
    target_rate_map = dict(low=0.14, med=0.3, high=0.45)
    lambda_A_map = dict(low=2**1, med=2**0, high=2**(-1))
    # Resolved once at class-creation time (see class docstring).
    target_rate = target_rate_map[regime]
    lambda_A = lambda_A_map[regime]
# ---------------------------------------------------------------------------
# Specialized configs
# ---------------------------------------------------------------------------


class mse_lpips_args(args):
    """Non-adversarial config: trained with distortion and perceptual loss only.

    Inherits every setting from `args` and only selects the plain
    compression model type.
    """

    model_type = ModelTypes.COMPRESSION
class hific_args(args):
    """Full generative config: trained with every loss term, including the GAN loss.

    Relative to `args`, it selects the GAN model type and chooses the GAN
    loss variant. It also raises `discriminator_steps` from 0 to 1.
    """

    model_type = ModelTypes.COMPRESSION_GAN
    # Options listed in the original note: 'non_saturating', 'least_squares'.
    gan_loss = 'non_saturating'
    discriminator_steps = 1