-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_diffusion.py
74 lines (60 loc) · 2.34 KB
/
train_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import random
import socket
import yaml
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import numpy as np
import torchvision
import models
import datasets
import utils
from models import DenoisingDiffusion
def parse_args_and_config(argv=None):
    """Parse command-line options and load the YAML training config.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behavior
            of reading ``sys.argv``.

    Returns:
        A ``(args, config)`` tuple. ``args`` is the ``argparse.Namespace`` of
        command-line options. ``config`` is the YAML file
        ``configs/<--config>`` converted into a nested ``argparse.Namespace``
        by ``dict2namespace``.

    Raises:
        FileNotFoundError: if ``configs/<--config>`` does not exist.
        ValueError: if the config file is empty or its top level is not a
            mapping.
    """
    parser = argparse.ArgumentParser(description='Training Patch-Based Denoising Diffusion Models')
    # The value is a file name resolved inside the local "configs" directory
    # (see the os.path.join below), not an arbitrary path.
    parser.add_argument("--config", type=str, required=False, default="lowlight.yml",
                        help="Name of the config file inside the 'configs' directory")
    parser.add_argument('--resume', default='', type=str,
                        help='Path for checkpoint to load and resume')
    parser.add_argument("--sampling_timesteps", type=int, default=25,
                        help="Number of implicit sampling steps for validation image patches")
    parser.add_argument("--image_folder", default='results/images/', type=str,
                        help="Location to save restored validation image patches")
    parser.add_argument('--seed', default=61, type=int, metavar='N',
                        help='Seed for initializing training (default: 61)')
    args = parser.parse_args(argv)

    config_path = os.path.join("configs", args.config)
    # YAML files are UTF-8 by spec; do not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python types (no arbitrary objects).
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file (or a list/scalar for other
    # documents); dict2namespace needs a mapping, so fail with a clear message.
    if not isinstance(config, dict):
        raise ValueError(
            "Config file '{}' must contain a YAML mapping at the top level, got {}".format(
                config_path, type(config).__name__))

    new_config = dict2namespace(config)
    return args, new_config
def dict2namespace(config):
    """Recursively convert a mapping into an ``argparse.Namespace``.

    Every key of ``config`` becomes an attribute of the returned namespace.
    Values that are themselves ``dict`` instances are converted recursively,
    so nested config sections support dotted access
    (e.g. ``config.data.dataset``). All other values, including lists (even
    lists that contain dicts), are stored unchanged.

    Args:
        config: Mapping to convert (here, the result of ``yaml.safe_load``).

    Returns:
        A new ``argparse.Namespace`` mirroring ``config``.
    """
    result = argparse.Namespace()
    for attr_name, attr_value in config.items():
        if isinstance(attr_value, dict):
            attr_value = dict2namespace(attr_value)
        setattr(result, attr_name, attr_value)
    return result
def main():
    """Training entry point.

    Steps:
        1. Parse the CLI options and the YAML config.
        2. Pick the device and seed the random number generators.
        3. Build the dataset named by ``config.data.dataset`` from the project
           ``datasets`` module.
        4. Hand the dataset to ``DenoisingDiffusion(args, config).train``.
    """
    args, config = parse_args_and_config()

    # setup device to run
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Using device: {}".format(device))
    config.device = device

    # Seed every RNG the training stack may draw from. Python's `random`
    # module was previously left unseeded, so code relying on it was not
    # reproducible for a fixed --seed.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # Let cuDNN benchmark and pick the fastest kernels. Per PyTorch's
    # reproducibility notes, this may select non-deterministic algorithms
    # even with fixed seeds.
    torch.backends.cudnn.benchmark = True

    # data loading: the dataset class is looked up by name in `datasets`
    dataset_name = config.data.dataset
    print("=> using dataset '{}'".format(dataset_name))
    try:
        dataset_factory = datasets.__dict__[dataset_name]
    except KeyError:
        raise ValueError(
            "Unknown dataset '{}': no such name in the datasets module "
            "(check data.dataset in the config)".format(dataset_name)) from None
    dataset = dataset_factory(config)

    # create model
    print("=> creating denoising-diffusion model...")
    diffusion = DenoisingDiffusion(args, config)
    diffusion.train(dataset)
if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # never as a side effect of importing it.
    main()