-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
144 lines (122 loc) · 4.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import traceback
import logging
import yaml
import sys
import os
import torch
import numpy as np
from solver_2d import Diffusion as Diffusion_2d
from solver_3d import Diffusion as Diffusion_3d
from pathlib import Path
# Print tensor values in fixed-point rather than scientific notation
# (affects only how tensors are formatted when printed/logged).
torch.set_printoptions(sci_mode=False)
def parse_args_and_config():
    """Parse the command line, load the YAML config and prepare the run.

    The config file named by ``--config`` is read from ``configs/vp/``
    (relative to the current working directory) and converted into nested
    ``argparse.Namespace`` objects via ``dict2namespace``.

    Side effects:
        * ``args.image_folder`` is replaced by a ``pathlib.Path``. For CT
          degradations it is extended with ``<deg>/view<Nview>``. For MRI
          degradations it is extended with ``<deg>/<mask_type>_acc<acc_factor>``.
          The resulting directory (and its parents) is created on disk.
        * The torch and numpy RNGs (and all CUDA devices, when CUDA is
          available) are seeded with ``--seed``.
        * ``torch.backends.cudnn.benchmark`` is enabled when CUDA is available.

    Returns:
        tuple: ``(args, config)``. ``args`` is the parsed CLI namespace.
        ``config`` is the config namespace, which additionally carries the
        selected ``torch.device`` as ``config.device``.

    Raises:
        ValueError: if the config file does not contain a YAML mapping.
    """
    parser = argparse.ArgumentParser(description=globals()["__doc__"])
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config file"
    )
    # Restrict --type to the solvers main() knows how to build, so a typo is
    # rejected with a usage error instead of failing later.
    parser.add_argument(
        "--type", type=str, required=True, choices=["2d", "3d"], help="Either [2d, 3d]"
    )
    parser.add_argument(
        "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
    )
    parser.add_argument(
        "--Nview", type=int, default=16, help="number of projections for CT"
    )
    parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
    parser.add_argument(
        "--exp", type=str, default="./exp", help="Path for saving running related data."
    )
    parser.add_argument(
        "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
    )
    parser.add_argument(
        "--deg", type=str, required=True, help="Degradation"
    )
    parser.add_argument(
        "--sigma_y", type=float, default=0., help="sigma_y"
    )
    parser.add_argument(
        "--eta", type=float, default=0.85, help="Eta"
    )
    parser.add_argument(
        "--rho", type=float, default=10.0, help="rho"
    )
    parser.add_argument(
        "--lamb", type=float, default=0.04, help="lambda for TV"
    )
    parser.add_argument(
        "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
    )
    parser.add_argument(
        "--T_sampling", type=int, default=50, help="Total number of sampling steps"
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="./results",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
    )
    # MRI-exp arguments
    parser.add_argument(
        "--mask_type", type=str, default="uniform1d", help="Undersampling type"
    )
    parser.add_argument(
        "--acc_factor", type=int, default=4, help="acceleration factor"
    )
    parser.add_argument(
        "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
    )
    parser.add_argument(
        "--center_fraction", type=float, default=0.08, help="ACS region"
    )
    args = parser.parse_args()

    # Parse the config file. safe_load returns None for an empty file, which
    # dict2namespace cannot handle, so fail early with a clear message.
    with open(Path("configs/vp") / args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {args.config!r} must contain a YAML mapping at the top level"
        )
    new_config = dict2namespace(config)

    # Organize results by degradation and its acquisition settings. Always
    # work with a Path so mkdir() also succeeds when --deg names neither a CT
    # nor an MRI degradation (the folder is then used as given).
    image_folder = Path(args.image_folder)
    if "CT" in args.deg:
        image_folder = image_folder / args.deg / f"view{args.Nview}"
    elif "MRI" in args.deg:
        image_folder = image_folder / args.deg / f"{args.mask_type}_acc{args.acc_factor}"
    # parents=True / exist_ok=True already covers every "does it exist" case.
    image_folder.mkdir(exist_ok=True, parents=True)
    args.image_folder = image_folder

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True
    return args, new_config
def dict2namespace(config):
    """Recursively turn a mapping into an ``argparse.Namespace``.

    Each key becomes an attribute of the returned namespace. Values that are
    ``dict`` instances are converted the same way, so nested mappings become
    nested namespaces. All other values (including lists) are stored as-is.
    """
    ns = argparse.Namespace()
    for attr_name, raw in config.items():
        converted = dict2namespace(raw) if isinstance(raw, dict) else raw
        setattr(ns, attr_name, converted)
    return ns
def main():
    """Build the 2-D or 3-D diffusion solver selected by ``--type`` and sample.

    Returns:
        int: process exit status. It is ``0`` on success and ``1`` when
        ``--type`` is not a known solver or when constructing the solver or
        sampling raised an exception (the traceback is logged).
    """
    args, config = parse_args_and_config()

    # Explicit dispatch table: an unsupported --type is reported clearly
    # instead of surfacing as an unbound-variable error.
    solvers = {"2d": Diffusion_2d, "3d": Diffusion_3d}
    solver_cls = solvers.get(args.type)
    if solver_cls is None:
        logging.error(
            "Unknown --type %r; expected one of %s", args.type, sorted(solvers)
        )
        return 1

    try:
        runner = solver_cls(args, config)
        runner.sample()
    except Exception:
        # Top-level boundary: log the full traceback and report failure through
        # a non-zero exit status rather than exiting as if sampling succeeded.
        logging.exception("Sampling failed")
        return 1
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    raise SystemExit(main())