-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
43 lines (32 loc) · 1.62 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import yaml
import argparse
from utils import yaml_utils
from utils.load import *
from training.trainer import GanTrainer
# Seed PyTorch's default generator and then the CUDA generators with the same
# fixed value (0) at import time, before any model or data loader is built.
# NOTE(review): `torch` is not imported by name in this file; presumably it
# reaches this namespace through `from utils.load import *` -- confirm.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)
def main(args):
    """Build the training components from the config and CLI options, then run the trainer.

    Args:
        args: Parsed command-line namespace. This function reads
            ``config_path``, ``batch_size``, ``data_dir``, ``loaderjob``,
            ``iterations`` and ``result_dir`` from it.
    """
    # The device is hard-coded to the first CUDA device; no CPU fallback here.
    device = torch.device("cuda:0")

    # Open the config inside a context manager so the file handle is closed
    # deterministically, and pass an explicit Loader: PyYAML >= 6 raises
    # TypeError when `yaml.load` is called without one. FullLoader is what
    # PyYAML 5.1+ picked implicitly for a bare `yaml.load(stream)`.
    # NOTE(review): if the config files never use Python-specific YAML tags,
    # `yaml.safe_load` would be the stricter choice -- confirm before switching.
    with open(args.config_path) as config_file:
        config = yaml_utils.Config(yaml.load(config_file, Loader=yaml.FullLoader))

    gen, dis = load_gan_model(config)

    # Each network gets its own optimizer and scheduler, built from the same config.
    gen_optimizer = load_optimizer(config, gen.parameters())
    dis_optimizer = load_optimizer(config, dis.parameters())
    scheduler_g = load_scheduler(config, gen_optimizer)
    scheduler_d = load_scheduler(config, dis_optimizer)

    dataset = load_dataset(args.batch_size, args.data_dir, args.loaderjob, config)
    evaluator = load_evaluator(config, device)

    # Extra trainer keyword arguments are taken from the `args` entry of the
    # config's `trainer` section.
    trainer = GanTrainer(
        args.iterations,
        dataset,
        gen,
        dis,
        gen_optimizer,
        dis_optimizer,
        args.result_dir,
        scheduler_g,
        scheduler_d,
        evaluator=evaluator,
        device=device,
        **config.trainer['args'],
    )
    trainer.run()
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered on the parser in exactly this order.
    cli_options = (
        ('--config_path', {'type': str, 'default': 'configs/base.yml',
                           'help': 'path to config file'}),
        ('--data_dir', {'type': str, 'default': './data/imagenet'}),
        ('--iterations', {'type': int, 'default': 250000}),
        ('--result_dir', {'type': str, 'default': './results/gans',
                          'help': 'directory to save the results to'}),
        ('--batch_size', {'type': int, 'default': 64,
                          'help': 'mini batch size'}),
        ('--loaderjob', {'type': int, 'default': 4,
                         'help': 'number of parallel data loading processes'}),
    )

    cli_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli_parser.add_argument(flag, **options)

    # Parse sys.argv and hand the resulting namespace to the training entry point.
    args = cli_parser.parse_args()
    main(args)