-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_restoration.py
65 lines (54 loc) · 2.55 KB
/
train_restoration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import os
import random

import numpy as np
import torch
import torch.utils.data
import yaml

import datasets
from models import DenoisingDiffusion_Restoration
def parse_args_and_config():
    """Parse command-line options and load the YAML config they point to.

    ``--config`` is first resolved relative to the ``configs`` directory. If
    no file exists there but the value itself is an existing path (e.g.
    ``configs/Restoration.yml`` or an absolute path), that path is used.

    Returns:
        tuple: ``(args, config)``. ``args`` is the parsed ``argparse.Namespace``
        of CLI options. ``config`` is the top-level YAML mapping converted to a
        nested ``argparse.Namespace`` by ``dict2namespace``.

    Raises:
        FileNotFoundError: if the config file cannot be opened.
        ValueError: if the config file's top level is not a YAML mapping
            (e.g. the file is empty).
    """
    parser = argparse.ArgumentParser(description='Training Patch-Based Denoising Diffusion Models')
    parser.add_argument("--config", type=str, default='Restoration.yml', help="Path to the config file")
    parser.add_argument("--phase", type=str, default='test', help="val(generation)")
    parser.add_argument("--data_type", type=str, default='VI', help="data type options: ['LOL', 'IR', 'VI', 'CT', 'MRI', 'CT_norm', 'MRI_norm']")
    parser.add_argument('--name', type=str, default='Restoration_VI_MSRS', help='folder name to save outputs')
    parser.add_argument("--data_dir", type=str, default='data/IVIF_degraded/train', help="root data path")
    parser.add_argument('--resume', default='', type=str, help='Path for checkpoint to load and resume')
    parser.add_argument("--sampling_timesteps", type=int, default=20, help="Number of implicit sampling steps for validation image patches")
    parser.add_argument('--seed', default=61, type=int, metavar='N', help='Seed for initializing training (default: 61)')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default="0")
    args = parser.parse_args()

    config_path = os.path.join("configs", args.config)
    if not os.path.isfile(config_path) and os.path.isfile(args.config):
        # Accept a direct path as well; joining "configs/X.yml" onto "configs"
        # would otherwise look for "configs/configs/X.yml".
        config_path = args.config
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain YAML types (no arbitrary Python objects).
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        # safe_load returns None for an empty file; dict2namespace needs a mapping.
        raise ValueError("Config file '{}' must contain a YAML mapping at the top level, got {}".format(
            config_path, type(config).__name__))
    new_config = dict2namespace(config)
    return args, new_config
def dict2namespace(config):
    """Recursively turn a mapping into an ``argparse.Namespace``.

    Every value that is itself a ``dict`` becomes a nested Namespace, so
    ``config["data"]["dataset"]`` is reachable as ``ns.data.dataset``.
    Any other value, including lists (even lists of dicts), is stored as-is.

    Args:
        config: mapping whose keys become attribute names.

    Returns:
        argparse.Namespace mirroring the structure of ``config``.
    """
    ns = argparse.Namespace()
    for name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(ns, name, converted)
    return ns
def main():
    """Parse options, seed RNGs, build the dataset and run restoration training.

    Steps:
      1. Read CLI args and the YAML config (``parse_args_and_config``).
      2. Pick the compute device and store it on ``config.device``.
      3. Seed Python's ``random``, NumPy and PyTorch with ``--seed``.
      4. Instantiate the dataset named by ``config.data.dataset`` from the
         ``datasets`` module as ``cls(config, args)``.
      5. Build ``DenoisingDiffusion_Restoration(args, config)`` and call
         ``train`` on the dataset.

    Raises:
        ValueError: if ``config.data.dataset`` is not a name defined in the
            ``datasets`` module.
    """
    args, config = parse_args_and_config()

    # setup device to run
    if torch.cuda.is_available():
        # --gpu_ids is a string that may list several ids ("0,1"). torch.device
        # takes a single index, so the first listed id is the primary device.
        # If none is given, fall back to the default CUDA device.
        gpu_ids = [gid.strip() for gid in args.gpu_ids.split(",") if gid.strip()]
        device = torch.device("cuda:{}".format(gpu_ids[0])) if gpu_ids else torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("Using device: {}".format(device))
    config.device = device

    # set random seed for every RNG the pipeline may draw from
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # Lets cuDNN autotune convolution algorithms for speed. Note this can make
    # runs non-bit-reproducible even with the seeds above.
    torch.backends.cudnn.benchmark = True

    # data loading
    dataset_name = config.data.dataset
    print("=> using dataset '{}'".format(dataset_name))
    dataset_factory = vars(datasets).get(dataset_name)
    if dataset_factory is None:
        raise ValueError("Unknown dataset '{}' (config.data.dataset): no such name in the 'datasets' module".format(
            dataset_name))
    DATASET = dataset_factory(config, args)

    # create model
    print("=> creating denoising-diffusion model...")
    diffusion = DenoisingDiffusion_Restoration(args, config)
    diffusion.train(DATASET)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()