main.py
"""
Runs the VaDeSC model.
"""
import argparse
import logging
import os
from pathlib import Path
import yaml
import tensorflow as tf
import tensorflow_probability as tfp
from models.losses import Losses
from train import run_experiment
tfd = tfp.distributions
tfkl = tf.keras.layers
tfpl = tfp.layers
tfk = tf.keras
# Project-wide constants:
ROOT_LOGGER_STR = "GMM_Survival"
LOGGER_RESULT_FILE = "logs.txt"
CHECKPOINT_PATH = 'models/Ours'
logger = logging.getLogger(ROOT_LOGGER_STR + '.' + __name__)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
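# Require at least one visible GPU and enable memory growth so that TensorFlow
# allocates GPU memory on demand instead of reserving it all up front.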
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)
def main():
    project_dir = Path(__file__).absolute().parent
    print(project_dir)
    parser = argparse.ArgumentParser()
    # Model parameters
    parser.add_argument('--data',
                        default='mnist',
                        type=str,
                        choices=['mnist', 'sim', 'support', 'flchain', 'hgg', 'hemo', 'lung1', 'nsclc',
                                 'nsclc_features', 'basel'],
                        help='the dataset (mnist, sim, support, flchain, hgg, hemo, lung1, nsclc, '
                             'nsclc_features, basel)')
    parser.add_argument('--num_epochs',
                        default=1000,
                        type=int,
                        help='the number of training epochs')
    parser.add_argument('--batch_size',
                        default=256,
                        type=int,
                        help='the mini-batch size')
    parser.add_argument('--lr',
                        default=0.001,
                        type=float,
                        help='the learning rate')
    parser.add_argument('--decay',
                        default=0.00001,
                        type=float,
                        help='the decay')
    parser.add_argument('--weibull_shape',
                        default=1.0,
                        type=float,
                        help='the Weibull shape parameter (global)')
    parser.add_argument('--no-survival',
                        dest='survival',
                        action='store_false',
                        help='specifies if the survival model should not be included')
    parser.add_argument('--dsa',
                        dest='dsa',
                        action='store_true',
                        help='specifies if the deep survival analysis with k-means should be run')
    parser.add_argument('--dsa_k',
                        default=1,
                        type=int,
                        help='number of clusters in deep survival analysis with k-means')
    parser.add_argument('--eval-cal',
                        default=False,
                        type=bool,
                        help='specifies if the calibration needs to be evaluated')
    parser.set_defaults(survival=True)
    # Other parameters
    parser.add_argument('--runs',
                        default=1,
                        type=int,
                        help='the number of runs; the results will be averaged')
    parser.add_argument('--results_dir',
                        default=os.path.join(project_dir, 'models/experiments'),
                        type=lambda p: Path(p).absolute(),
                        help='the directory where the results will be saved')
    parser.add_argument('--results_fname',
                        default='',
                        type=str,
                        help='the name of the .txt file with the results')
    parser.add_argument('--pretrain', default=False, type=bool,
                        help='specifies if the autoencoder should be pretrained')
    parser.add_argument('--epochs_pretrain', default=10, type=int,
                        help='the number of pretraining epochs')
    parser.add_argument('--save_model', default=False, type=bool,
                        help='specifies if the model should be saved')
    parser.add_argument('--ex_name', default='', type=str, help='the experiment name')
    parser.add_argument('--config_override', default='', type=str,
                        help='the config file overriding the default configs/<data>.yml')
    parser.add_argument('--seed', default=42, type=int, help='random number generator seed')
    parser.add_argument('--eager',
                        default=False,
                        type=bool,
                        help='specifies if the TF functions should be run eagerly')
    args = parser.parse_args()
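    # Each dataset has a YAML configuration file under configs/<data>.yml.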
    data_name = args.data + '.yml'
    config_path = project_dir / 'configs' / data_name
    # Check for config override
    if args.config_override != "":
        config_path = Path(args.config_override)
    with config_path.open(mode='r') as yamlfile:
        configs = yaml.safe_load(yamlfile)
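    # Binary cross-entropy reconstruction loss for MNIST images, MSE for the other datasets.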
    losses = Losses(configs)
    if args.data == "mnist":
        loss = losses.loss_reconstruction_binary
    else:
        loss = losses.loss_reconstruction_mse
    run_experiment(args, configs, loss)
if __name__ == "__main__":
    main()