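"""Command-line configuration for training and testing: argument definitions,
dataset-dependent defaults, and helpers to save, load and print the parsed
options."""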
import argparse
import os
import pickle

import utils.utils as utils


def read_arguments(train=True):
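    """Build the argument parser, parse the command line and return the options.

    For training, dataset-dependent defaults are applied first and, when
    --continue_train is set, the option values saved by a previous run are
    restored before the final parse. The random seed is then fixed and the
    resulting options are printed (and, for training, saved to disk).
    """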
    parser = argparse.ArgumentParser()
    parser = add_all_arguments(parser, train)
    parser.add_argument('--phase', type=str, default='train')
    opt = parser.parse_args()
    if train:
        set_dataset_default_lm(opt, parser)
        if opt.continue_train:
            update_options_from_file(opt, parser)
        opt = parser.parse_args()
    opt.phase = 'train' if train else 'test'
    if train:
        opt.loaded_latest_iter = 0 if not opt.continue_train else load_iter(opt)
    utils.fix_seed(opt.seed)
    print_options(opt, parser)
    if train:
        save_options(opt, parser)
    return opt


def add_all_arguments(parser, train):
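    """Register all command-line options; training-only or test-only options are
    added depending on the `train` flag."""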
    #--- general options ---
    parser.add_argument('--name', type=str, default='label2coco', help='name of the experiment. It decides where to store samples and models')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
    parser.add_argument('--no_spectral_norm', action='store_true', help='if specified, no spectral norm is applied')
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
    parser.add_argument('--dataset_mode', type=str, default='coco')
    parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
    # for generator
    parser.add_argument('--num_res_blocks', type=int, default=6, help='number of residual blocks in G and D')
    parser.add_argument('--channels_G', type=int, default=64, help='# of gen filters in first conv layer')
    parser.add_argument('--param_free_norm', type=str, default='syncbatch', help='which parameter-free norm to use in G')
    parser.add_argument('--spade_ks', type=int, default=3, help='kernel size of convs inside spade')
    parser.add_argument('--no_EMA', action='store_true', help='if specified, do not compute exponential moving averages')
    parser.add_argument('--EMA_decay', type=float, default=0.9999, help='decay in exponential moving averages')
    parser.add_argument('--no_3dnoise', action='store_true', default=False)
    parser.add_argument('--z_dim', type=int, default=64, help="dimension of the latent z vector")
    if train:
        parser.add_argument('--freq_print', type=int, default=1000, help='frequency of showing training results on console')
        parser.add_argument('--freq_save_ckpt', type=int, default=20000, help='frequency of saving the checkpoints')
        parser.add_argument('--freq_save_latest', type=int, default=10000, help='frequency of saving the latest model')
        parser.add_argument('--freq_smooth_loss', type=int, default=250, help='smoothing window for loss visualization')
        parser.add_argument('--freq_save_loss', type=int, default=2500, help='frequency of loss plot updates')
        parser.add_argument('--freq_fid', type=int, default=5000, help='frequency of computing the FID score (in training iterations)')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--which_iter', type=str, default='latest', help='which iteration to load when continue_train is set')
        parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train')
        parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
        parser.add_argument('--lr_g', type=float, default=0.0001, help='G learning rate, default=0.0001')
        parser.add_argument('--lr_d', type=float, default=0.0004, help='D learning rate, default=0.0004')
        parser.add_argument('--channels_D', type=int, default=64, help='# of discrim filters in first conv layer')
        parser.add_argument('--add_vgg_loss', action='store_true', help='if specified, add VGG feature matching loss')
        parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
        parser.add_argument('--no_balancing_inloss', action='store_true', default=False)
        parser.add_argument('--no_labelmix', action='store_true', default=False)
        parser.add_argument('--lambda_labelmix', type=float, default=10.0, help='weight for feature consistency regularization')
    else:
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--ckpt_iter', type=str, default='best', help='which checkpoint to load to evaluate the model')
    return parser


def set_dataset_default_lm(opt, parser):
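    """Override the option defaults that depend on the chosen --dataset_mode."""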
    if opt.dataset_mode == "ade20k":
        parser.set_defaults(lambda_labelmix=10.0)
        parser.set_defaults(EMA_decay=0.9999)
    if opt.dataset_mode == "cityscapes":
        parser.set_defaults(lr_g=0.0004)
        parser.set_defaults(lambda_labelmix=5.0)
        parser.set_defaults(freq_fid=2500)
        parser.set_defaults(EMA_decay=0.999)
    if opt.dataset_mode == "coco":
        parser.set_defaults(lambda_labelmix=10.0)
        parser.set_defaults(EMA_decay=0.9999)


def save_options(opt, parser):
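    """Write the options to <checkpoints_dir>/<name>/opt.txt (human-readable,
    non-default values marked) and opt.pkl (pickled, reloaded when resuming)."""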
    path_name = os.path.join(opt.checkpoints_dir, opt.name)
    os.makedirs(path_name, exist_ok=True)
    with open(path_name + '/opt.txt', 'wt') as opt_file:
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
    with open(path_name + '/opt.pkl', 'wb') as opt_file:
        pickle.dump(opt, opt_file)


def update_options_from_file(opt, parser):
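    """Reset the parser defaults to the option values saved by a previous run,
    so that --continue_train resumes with the original configuration."""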
    new_opt = load_options(opt)
    for k, v in sorted(vars(opt).items()):
        if hasattr(new_opt, k) and v != getattr(new_opt, k):
            new_val = getattr(new_opt, k)
            parser.set_defaults(**{k: new_val})
    return parser


def load_options(opt):
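    """Load the pickled options saved by a previous run of save_options."""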
    file_name = os.path.join(opt.checkpoints_dir, opt.name, "opt.pkl")
    with open(file_name, 'rb') as opt_file:
        new_opt = pickle.load(opt_file)
    return new_opt


def load_iter(opt):
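    """Return the iteration to resume from: the value stored in latest_iter.txt
    when --which_iter is 'latest', otherwise the given iteration number."""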
    if opt.which_iter == "latest":
        with open(os.path.join(opt.checkpoints_dir, opt.name, "latest_iter.txt"), "r") as f:
            res = int(f.read())
        return res
    else:
        return int(opt.which_iter)


def print_options(opt, parser):
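    """Print all options, marking the ones that differ from their defaults."""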
    message = ''
    message += '----------------- Options ---------------\n'
    for k, v in sorted(vars(opt).items()):
        comment = ''
        default = parser.get_default(k)
        if v != default:
            comment = '\t[default: %s]' % str(default)
        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
    message += '----------------- End -------------------'
    print(message)
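

# Minimal usage sketch (an assumption about how a training entry point would
# call this module; the repository's actual scripts may wire things up
# differently):
#
#   import config
#
#   opt = config.read_arguments(train=True)
#   # `opt` now holds every parsed option; dataset-specific defaults (e.g. for
#   # --dataset_mode cityscapes) are applied, the seed is fixed, and the options
#   # have been printed and saved to <checkpoints_dir>/<name>/opt.txt and opt.pkl.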