"""Trains a model on one of the three classification tasks."""
import numpy as np
import pandas as pd
import argparse
import os
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from hamlet import models


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task',
                        type=str,
                        default='abnormal',
                        help='Which prediction task to train on.',
                        choices=['abnormal', 'abnormal_tb', 'findings'])
    parser.add_argument('--data_dir',
                        type=str,
                        default='C:/Users/yle4/data/',
                        help='Path to the directory holding the three image \
                        dataset folders (train, val, and test) and the \
                        CSV file with the image-level labels and metadata.')
    parser.add_argument('--csv_name',
                        type=str,
                        default='samp.csv',
                        help='Name of the CSV file (including the file \
                        extension) holding the image-level labels and \
                        metadata.')
    parser.add_argument('--train_mod_folder',
                        type=str,
                        default=None,
                        help='Folder holding the model file to be used for \
                        training. Ignored if --mode is "test".')
    parser.add_argument('--model_flavor',
                        type=str,
                        default='EfficientNetV2M',
                        help='What pretrained model to use as the feature \
                        extractor.')
    parser.add_argument('--no_augmentation',
                        action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=12,
                        help='Minibatch size for model training and inference.')
    parser.add_argument('--metric',
                        type=str,
                        default='val_ROC_AUC',
                        help='Which metric to use for early stopping.')
    parser.add_argument('--metric_mode',
                        type=str,
                        default='max',
                        help='Whether to min or max the metric',
                        choices=['min', 'max'])
    parser.add_argument('--distributed',
                        action='store_true',
                        help='Turns on distributed (multi-GPU) training')
    parser.add_argument('--validate_on',
                        type=str,
                        default='hamlet',
                        choices=['hamlet', 'nih'],
                        help='Which dataset to use for validation.')
    parser.set_defaults(no_augmentation=False,
                        progressive=False,
                        distributed=False)
    args = parser.parse_args()
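    # Example invocation (paths and values here are illustrative only):
    #   python train.py --task findings --data_dir /path/to/data/ --batch_size 16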
    # Parameters
    AUGMENT = not args.no_augmentation
    MODEL_FLAVOR = args.model_flavor
    BATCH_SIZE = args.batch_size
    TASK = args.task
    METRIC = args.metric
    METRIC_MODE = args.metric_mode
    DISTRIBUTED = args.distributed
    VALIDATE_ON = args.validate_on

    # Directories
    DATA_DIR = args.data_dir
    HAM_DIR = DATA_DIR + 'hamlet/'
    OUT_DIR = 'output/' + args.task + '/'
    CHECK_DIR = OUT_DIR + 'checkpoints/'
    LOG_DIR = OUT_DIR + 'logs/'
    TRAIN_MOD_FOLDER = args.train_mod_folder
    # Just some info
    if AUGMENT:
        print('Augmentation on.')

    # Setting training strategy
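    # Note: HierarchicalCopyAllReduce works without NCCL, which
    # MirroredStrategy's default all-reduce relies on and which isn't
    # available on all platforms (e.g., Windows).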
    if DISTRIBUTED:
        print('Using multiple GPUs.\n')
        cdo = tf.distribute.HierarchicalCopyAllReduce()
        strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    else:
        strategy = tf.distribute.get_strategy()
    # Reading the labels
    records = pd.read_csv(HAM_DIR + args.csv_name, encoding='latin')
    if TASK == 'findings':
        LABEL_COL = [
            'infiltrate', 'reticular', 'cavity',
            'nodule', 'pleural_effusion', 'hilar_adenopathy',
            'linear_opacity', 'discrete_nodule', 'volume_loss',
            'pleural_reaction', 'other', 'miliary'
        ]
        NUM_CLASSES = len(LABEL_COL)
    else:
        LABEL_COL = TASK
        NUM_CLASSES = 1
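    # Missing labels are treated as negatives (0)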
    records[LABEL_COL] = records[LABEL_COL].fillna(0).astype(np.uint8)
    train = records[records.split == 'train'].reset_index(drop=True)
    # Parameters for the data loader
    img_height = 600
    img_width = 600

    # Loading the training data
    train_dg = ImageDataGenerator()
    train_dir = HAM_DIR + 'train/img/'
    train_gen = train_dg.flow_from_dataframe(dataframe=train,
                                             directory=train_dir,
                                             x_col='file',
                                             y_col=LABEL_COL,
                                             class_mode='raw',
                                             target_size=(img_height,
                                                          img_width),
                                             batch_size=BATCH_SIZE)
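    # class_mode='raw' passes the label column(s) through unchanged, which is
    # what the multi-label 'findings' task needs. The bare ImageDataGenerator
    # only loads and resizes images; augmentation, when enabled, is presumably
    # applied inside the model via the augmentation flag passed to
    # models.EfficientNet below.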
    # Loading the validation data
    val_dg = ImageDataGenerator()
    if VALIDATE_ON == 'hamlet':
        val = records[records.split == 'val'].reset_index(drop=True)
        val_dir = HAM_DIR + 'val/img/'
    elif VALIDATE_ON == 'nih':
        nih_labels = pd.read_csv(DATA_DIR + 'nih/labels.csv')
        val = nih_labels[nih_labels.split == 'val']
        val_dir = DATA_DIR + 'nih/val/img/'

    val_gen = val_dg.flow_from_dataframe(dataframe=val,
                                         directory=val_dir,
                                         x_col='file',
                                         y_col=LABEL_COL,
                                         class_mode='raw',
                                         shuffle=False,
                                         target_size=(img_height,
                                                      img_width),
                                         batch_size=BATCH_SIZE)
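    # shuffle=False keeps the validation batches in dataframe order, so any
    # saved predictions can be matched back to the rows in val.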
    # Setting up callbacks and metrics
    tr_callbacks = [
        callbacks.EarlyStopping(patience=2,
                                mode=METRIC_MODE,
                                monitor=METRIC,
                                restore_best_weights=True,
                                verbose=1),
        callbacks.ModelCheckpoint(filepath=CHECK_DIR + 'training/',
                                  save_weights_only=True,
                                  mode=METRIC_MODE,
                                  monitor=METRIC,
                                  save_best_only=True),
        callbacks.TensorBoard(log_dir=LOG_DIR + 'training/')
    ]
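    # The monitored name (e.g., the default 'val_ROC_AUC') has to match a
    # metric compiled into the model, presumably one named 'ROC_AUC' inside
    # models.EfficientNet.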
    with strategy.scope():
        # Setting up a fresh model
        mod = models.EfficientNet(num_classes=NUM_CLASSES,
                                  img_height=img_height,
                                  img_width=img_width,
                                  augmentation=AUGMENT,
                                  learning_rate=1e-4,
                                  model_flavor=MODEL_FLAVOR,
                                  effnet_trainable=True)

        if TRAIN_MOD_FOLDER:
            mod.load_weights(CHECK_DIR + TRAIN_MOD_FOLDER)

    mod.fit(train_gen,
            validation_data=val_gen,
            callbacks=tr_callbacks,
            epochs=20)
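    # If early stopping triggers, restore_best_weights=True rolls mod back to
    # its best epoch (per METRIC); the best weights are also saved under
    # CHECK_DIR + 'training/' by the ModelCheckpoint callback.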