-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathUNET_train.py
168 lines (132 loc) · 5.36 KB
/
UNET_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import sys
import time
import argparse
from glob import glob
import h5py
import numpy as np
from random import shuffle
from tensorflow import keras
# custom tools
sys.path.insert(0, '/your-path-of-repo/utils/')
from namelist import *
import model_utils as mu
import train_utils as tu
# ---------------------------------------------------------- #
# collect and validate command-line arguments
cli = argparse.ArgumentParser()
# choices: 'TMAX', 'TMIN', 'TMEAN', 'PCT'
cli.add_argument('v', help='Downscaling variable name')
# choices: 'annual', 'summer', 'winter', 'djf', 'mam', 'jja', 'son'
cli.add_argument('s', help='Training seasons')
# an integer from 1 to 5; 3 is the expected setting,
# the others are trial-and-error options
cli.add_argument('c1', help='Number of input channels (1<c<5)')
# 1 is the expected setting; larger values are reserved for UNET-AE
cli.add_argument('c2', help='Number of output channels (1<c<3)')
user_args = vars(cli.parse_args())
# translate the raw arguments into training settings
VAR, seasons, input_flag, output_flag = tu.parser_handler(user_args)
N_input = int(np.sum(input_flag))
N_output = int(np.sum(output_flag))
# the plain UNet maps N_input channels onto a single target channel
if N_output > 1:
    raise ValueError('UNet accepts only one target')
# ---------------------------------------------------------- #
# hidden-layer filter counts per level, chosen by the
# downscaling target variable
if VAR != 'PCT':
    print('T2 hidden layer setup')
    N = [56, 112, 224, 448]
else:
    print('PCT hidden layer setup')
    N = [64, 96, 128, 160]
# ---------------------------------------------------------- #
# optimization hyperparameters
lr = 5e-5           # Adam learning rate
epochs = 150        # maximum number of training epochs
activation='relu'
pool=False # stride convolution instead of maxpooling
# early stopping settings
min_del = 0         # minimum decrease counted as an improvement
max_tol = 3 # early stopping with patience
# ---------------------------------------------------------- #
# training, one model per season
for sea in seasons:
    # UNET configuration
    dscale_unet = mu.UNET(N, (None, None, N_input), pool=pool, activation=activation)
    # NOTE: the 'lr' keyword was deprecated in tf.keras (removed in Keras 3);
    # 'learning_rate' is the backward-compatible spelling
    opt_ = keras.optimizers.Adam(learning_rate=lr)
    dscale_unet.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_)
    # check point settings
    # "temp_dir" is where models are saved, defined in the namelist.py
    save_name = 'UNET_raw_{}_{}'.format(VAR, sea)
    save_path = temp_dir+save_name+'/'
    hist_path = temp_dir+'{}_loss.npy'.format(save_name)
    # ---------------------- data section ---------------------- #
    # before training starts, the authors have saved individual
    # training batches separately as numpy files, and the entire
    # validation set as a hdf5 file
    # "input_flag" and "output_flag" are boolean indices that
    # subset input and training slices from the saved numpy arrays
    # Size of data is out of repo capacity, so pseudo code is
    # provided here
    # ---------------------------------------------------------- #
    trainfiles = glob('/user/drive/dscal_proj/batch_train_{}_{}_*.npy'.format(VAR, sea))
    # BUG FIX: L_train must be known BEFORE allocating T_LOSS below;
    # it was previously assigned only after the allocation (NameError)
    L_train = len(trainfiles)
    # glob returns a list; open the first (expected: only) match.
    # NOTE(review): the pattern ends with '.npy' but the file is read
    # as HDF5 — confirm the actual validation file name
    validfile = glob('/user/drive/dscal_proj/valid_{}_{}_*.npy'.format(VAR, sea))[0]
    with h5py.File(validfile, 'r') as h5io:
        valid_all = h5io['valid'][...] # shape = (sample, x, y, channel)
    X_valid = valid_all[..., input_flag]
    Y_valid = valid_all[..., output_flag]
    # allocate NaN-filled arrays that back up per-step training loss
    # and per-epoch validation loss
    T_LOSS = np.full((int(epochs*L_train),), np.nan)
    V_LOSS = np.full((epochs,), np.nan)
    # baseline validation error before any training; "record" tracks
    # the best validation loss seen so far
    Y_pred = dscale_unet.predict([X_valid])
    record = tu.mean_absolute_error(Y_pred, Y_valid)
    print('Initial validation loss: {}'.format(record))
    # epoch begins
    tol = 0  # consecutive epochs without improvement
    for i in range(epochs):
        print('epoch = {}'.format(i))
        # shuffle batch order on epoch-begin
        shuffle(trainfiles)
        # loop over batches
        for j, name in enumerate(trainfiles):
            # import batch data (numpy arrays)
            # temp_batch.shape = (sample_per_batch, x, y, channel)
            temp_batch = np.load(name)
            X = temp_batch[..., input_flag]
            Y = temp_batch[..., output_flag]
            # train on batch
            loss_ = dscale_unet.train_on_batch(X, Y)
            # Backup training loss
            T_LOSS[i*L_train+j] = loss_
            # print out
            if j%50 == 0:
                print('\t{} step loss = {}'.format(j, loss_))
        # validate on epoch-end
        Y_pred = dscale_unet.predict([X_valid])
        record_temp = tu.mean_absolute_error(Y_pred, Y_valid)
        # Backup validation loss
        V_LOSS[i] = record_temp
        # Save loss info every epoch so a crash loses at most one epoch
        loss_dict = {'T_LOSS':T_LOSS, 'V_LOSS':V_LOSS}
        np.save(hist_path, loss_dict)
        # early stopping: save the model whenever validation improves,
        # otherwise count patience
        if record - record_temp > min_del:
            print('Validation loss improved from {} to {}'.format(record, record_temp))
            record = record_temp
            tol = 0
            print('tol: {}'.format(tol))
            # save
            print('save to: {}'.format(save_path))
            dscale_unet.save(save_path)
        else:
            print('Validation loss {} NOT improved'.format(record_temp))
            tol += 1
            print('tol: {}'.format(tol))
            if tol >= max_tol:
                # BUG FIX: sys.exit() here aborted the whole script and
                # skipped the remaining seasons; break moves on to the
                # next season instead
                print('Early stopping')
                break
            else:
                print('Pass to the next epoch')
                continue