-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmisc_utils.py
461 lines (401 loc) · 14.6 KB
/
misc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
# Built-in
import os
import time
import json
import pickle
import collections.abc
from glob import glob
from functools import wraps
# Libs
import torch
import scipy.signal
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from skimage import io
from natsort import natsorted
from torchsummary import summary
# Own modules
def set_gpu(gpu, enable_benchmark=True):
    """
    Select the GPU(s) to use and configure cuDNN benchmarking.

    A single id (e.g. 0, '3', '10') selects that device directly. A ',' separated list
    of ids (e.g. '0,1') restricts CUDA_VISIBLE_DEVICES to those ids and returns the first
    visible device together with parallel=True, as an indicator for a parallel model.

    :param gpu: GPU id as a number, or a string with one or more device ids separated by ','
    :param enable_benchmark: if True, will let CUDNN find optimal set of algorithms for input configuration
    :return: tuple (device, parallel): the torch.device to use, and True if more than one id was given
    """
    if not isinstance(gpu, str):
        gpu = str(int(gpu))
    # Decide on multi-GPU by the number of ids, not the string length:
    # '10' is one device id, not two devices.
    gpu_ids = [g.strip() for g in gpu.split(',') if g.strip()]
    if len(gpu_ids) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
        parallel = True
        # Once CUDA_VISIBLE_DEVICES is restricted, the visible GPUs are indexed from 0,
        # so the primary device is cuda:0 ('cuda:0,1' is not a valid device string).
        device = torch.device("cuda:0")
        print("Devices being used: cuda:", gpu)
    else:
        parallel = False
        device = torch.device("cuda:{}".format(gpu_ids[0] if gpu_ids else gpu))
        print("Device being used:", device)
    torch.backends.cudnn.benchmark = enable_benchmark
    return device, parallel
def make_dir_if_not_exist(dir_path):
    """
    Make the directory (including missing parents) if it does not exist yet.

    If ``dir_path`` already exists, nothing happens.

    :param dir_path: absolute path to the directory
    :return: None
    """
    if not os.path.exists(dir_path):
        # exist_ok guards against another process creating the directory between the
        # existence check and this call (e.g. several workers started at once).
        os.makedirs(dir_path, exist_ok=True)
def timer_decorator(func):
    """
    Decorator that prints the wall-clock running time of each call to ``func``.

    The decorated function returns whatever ``func`` returns.

    :param func: the function to time
    :return: the wrapped function
    """
    @wraps(func)
    def timer_wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, suitable for measuring durations.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print('duration: {:.3f}s'.format(duration))
        return result
    return timer_wrapper
def str2list(s, sep=',', d_type=int):
    """
    Change a {sep} separated string into a list of items with d_type.

    A value that is already a list (including list subclasses) is returned unchanged,
    so the function can safely be applied to already-parsed values.

    :param s: input string, or an existing list
    :param sep: separator for string
    :param d_type: callable applied to each piece, e.g. int, float or str
    :return: list of converted items, or ``s`` itself if it is already a list
    """
    if isinstance(s, list):
        return s
    return [d_type(a) for a in s.split(sep)]
def load_file(file_name, **kwargs):
    """
    Read a data file, choosing the reader from the file extension.

    - ``npy``: numpy.load
    - ``pkl`` / ``pickle``: pickle.load
    - ``txt``: list of lines (file.readlines)
    - ``csv``: numpy.genfromtxt with ',' delimiter
    - ``json``: json.load
    - anything else is read as an image. With ``pil=True`` a PIL Image object is
      returned. Otherwise skimage.io.imread is used, falling back to PIL (converted to
      an RGB numpy array) when skimage raises ValueError or OSError.

    :param file_name: path to the file (str or os.PathLike)
    :param kwargs: ``pil=True`` to open images with PIL and return the Image object
    :return: file data
    :raises IOError: if the file cannot be read (the original error is chained)
    """
    try:
        # Accept pathlib.Path and other os.PathLike objects as well as plain strings.
        file_name = os.fspath(file_name)
        if file_name.endswith('npy'):
            data = np.load(file_name)
        elif file_name.endswith(('pkl', 'pickle')):
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(file_name, 'rb') as f:
                data = pickle.load(f)
        elif file_name.endswith('txt'):
            with open(file_name, 'r') as f:
                data = f.readlines()
        elif file_name.endswith('csv'):
            data = np.genfromtxt(file_name, delimiter=',', dtype=None, encoding='UTF-8')
        elif file_name.endswith('json'):
            # Context manager so the handle is closed right after reading.
            with open(file_name, 'r') as f:
                data = json.load(f)
        elif kwargs.get('pil'):
            try:
                data = Image.open(file_name)
            except Image.DecompressionBombError:
                # Lift PIL's pixel limit (process-wide) and retry for very large images.
                Image.MAX_IMAGE_PIXELS = None
                data = Image.open(file_name)
        else:
            try:
                data = io.imread(file_name)
            except Image.DecompressionBombError:
                Image.MAX_IMAGE_PIXELS = None
                data = io.imread(file_name)
            except (ValueError, OSError):
                # A tuple is required to catch both exception types; fall back to PIL
                # when skimage cannot decode the file.
                data = np.array(Image.open(file_name).convert('RGB'))
        return data
    except Exception as err:  # so many things could go wrong, can't be more specific.
        raise IOError('Problem loading {}'.format(file_name)) from err
def save_file(file_name, data, fmt='%.8e', sort_keys=True, indent=4):
    """
    Save data to the given path, choosing the writer from the file extension.

    - ``npy``: numpy.save
    - ``pkl`` / ``pickle``: pickle.dump (same extensions load_file accepts)
    - ``txt``: file.writelines(data)
    - ``csv``: numpy.savetxt with ',' delimiter and ``fmt``
    - ``json``: json.dump with ``sort_keys`` and ``indent``
    - anything else: saved as an image via PIL after casting to uint8

    :param file_name: path to the file (str or os.PathLike)
    :param data: data to save
    :param fmt: number format used for csv files
    :param sort_keys: sort dictionary keys when writing json
    :param indent: indentation used when writing json
    :return: None
    :raises IOError: if the data cannot be saved (the original error is chained)
    """
    try:
        # Accept pathlib.Path and other os.PathLike objects as well as plain strings.
        file_name = os.fspath(file_name)
        if file_name.endswith('npy'):
            np.save(file_name, data)
        elif file_name.endswith(('pkl', 'pickle')):
            with open(file_name, 'wb') as f:
                pickle.dump(data, f)
        elif file_name.endswith('txt'):
            with open(file_name, 'w') as f:
                f.writelines(data)
        elif file_name.endswith('csv'):
            np.savetxt(file_name, data, delimiter=',', fmt=fmt)
        elif file_name.endswith('json'):
            # Context manager so the data is flushed and the handle closed immediately.
            with open(file_name, 'w') as f:
                json.dump(data, f, sort_keys=sort_keys, indent=indent)
        else:
            image = Image.fromarray(np.asarray(data).astype(np.uint8))
            image.save(file_name)
    except Exception as err:  # so many things could go wrong, can't be more specific.
        raise IOError('Problem saving {}'.format(file_name)) from err
def get_img_channel_num(file_name):
    """
    Return the number of channels of an image file.

    The file is read with load_file. A 2-D array counts as a single channel; for a
    3-D array the size of the last axis is returned.

    :param file_name: absolute path to the image file
    :return: number of channels
    :raises ValueError: if the loaded image has neither 2 nor 3 dimensions
    """
    image = load_file(file_name)
    rank = len(image.shape)
    if rank == 3:
        return image.shape[-1]
    if rank == 2:
        return 1
    raise ValueError('Image can only have 2 or 3 dimensions')
def rotate_list(l):
    """
    Transpose a list of lists: row i of the result holds the i-th element of every input row.

    Because zip is used, rows are truncated to the length of the shortest input row.

    :param l: list of lists to transpose
    :return: new list of lists
    """
    return [list(column) for column in zip(*l)]
def make_center_string(char, length, center_str=''):
    """
    Build a one-line decoration string with ``center_str`` centered and padded with ``char``.

    The total length is ``length``. If ``center_str`` is already longer, it is returned
    unchanged (str.center semantics).

    :param char: single padding character
    :param length: total length of the string
    :param center_str: text shown at the center
    :return: the decorated string
    """
    banner = center_str.center(length, char)
    return banner
def float2str(f):
    """
    Format a number as a string with every '.' replaced by 'p' (e.g. 0.5 -> '0p5').

    :param f: float number
    :return: changed string
    """
    text = format(f)
    return 'p'.join(text.split('.'))
def stem_string(s, lower=True):
    """
    Strip leading/trailing whitespace and optionally lower-case the result.

    :param s: input string
    :param lower: if True, the stripped string is also converted to lower case
    :return: stemmed string
    """
    stripped = s.strip()
    return stripped.lower() if lower else stripped
def remove_digits(s):
    """
    Return ``s`` with every digit character (str.isdigit) removed.

    :param s: input string
    :return: digits removed string
    """
    return ''.join(filter(lambda ch: not ch.isdigit(), s))
def get_digits(s):
    """
    Concatenate the decimal digits found in ``s`` and return them as an int.

    Example: 'epoch_12.pth' -> 12, 'a1b2' -> 12.

    :param s: input string
    :return: int built from the digits of ``s``
    :raises ValueError: if ``s`` contains no decimal digit
    """
    # isdecimal (not isdigit): characters such as '²' satisfy isdigit but int() rejects them.
    digits = ''.join(c for c in s if c.isdecimal())
    if not digits:
        raise ValueError('No digits found in {!r}'.format(s))
    return int(digits)
def get_model_summary(model, shape, device=None):
    """
    Print a model summary with torchsummary.

    :param model: the model to visualize summary; it is moved to ``device`` via model.to
    :param shape: shape of the input data, passed to torchsummary as its input size
    :param device: torch.device or device string (e.g. 'cpu', 'cuda:1');
                   defaults to CUDA when available, otherwise CPU
    :return: None
    """
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torchsummary creates its dummy input according to its own ``device`` argument
    # (default 'cuda') and expects a device type string, so pass the type explicitly;
    # otherwise a model placed on the CPU could be fed CUDA tensors.
    summary(model.to(device), shape, device=torch.device(device).type)
def set_random_seed(seed_):
    """
    Seed numpy and torch, and switch cuDNN to deterministic mode.

    :param seed_: random seed to use, could be your lucky number :)
    :return: None
    """
    np.random.seed(seed_)
    torch.manual_seed(seed_)
    torch.backends.cudnn.deterministic = True
def normalize_rgb(rgb):
    """
    Scale 0~255 rgb values to the 0~1 range.

    :param rgb: a tuple of channel values, or a (possibly nested) list of such tuples
    :return: a tuple of floats, or a list with the same nesting as the input
    :raises NotImplementedError: for any input that is neither a tuple nor a list
    """
    if isinstance(rgb, list):
        return [normalize_rgb(entry) for entry in rgb]
    if isinstance(rgb, tuple):
        return tuple(float(channel) / 255 for channel in rgb)
    raise NotImplementedError('Data type: {} not understood'.format(type(rgb)))
def args_getter(inspect_class):
    """
    Collect the attributes stored in an object's ``__dict__``, skipping names starting with '__'.

    :param inspect_class: the class (or instance) to be inspected
    :return: dict of attribute name -> value, in ``__dict__`` order
    """
    return {name: value
            for name, value in inspect_class.__dict__.items()
            if not name.startswith('__')}
def args_writer(file_name, inspect_class):
    """
    Write the attributes returned by args_getter to ``file_name`` with save_file
    (sorted keys, indent 4).

    :param file_name: path to the file to be saved
    :param inspect_class: the class whose parameters are going to be saved
    :return: None
    """
    save_file(file_name, args_getter(inspect_class), sort_keys=True, indent=4)
def read_tensorboard_csv(file, field='Value', smooth=True, window_size=11, order=2):
    """
    Load the 'Step' column and one value column from a tensorboard csv export,
    optionally smoothing the values with a Savitzky-Golay filter.

    :param file: the csv file downloaded from the tensorboard (path or file-like)
    :param field: the name of the column to read as values
    :param smooth: if True, apply scipy.signal.savgol_filter to the values
    :param window_size: window size of the savgol filter
    :param order: polynomial order of the savgol filter
    :return: tuple (step, value) of numpy arrays
    """
    frame = pd.read_csv(file, skipinitialspace=True, usecols=['Step', field])
    steps = np.array(frame['Step'])
    raw_values = np.array(frame[field])
    values = scipy.signal.savgol_filter(raw_values, window_size, order) if smooth else raw_values
    return steps, values
def get_default_colors():
    """
    Return the colors of matplotlib's current default property cycle.

    :return: list of color specifications from rcParams['axes.prop_cycle']
    """
    return plt.rcParams['axes.prop_cycle'].by_key()['color']
def get_file_name_no_extension(file_name):
    """
    Strip the directory part and the last extension from a path.

    Example: 'a/b/c.tar.gz' -> 'c.tar'.

    :param file_name: the name or path to the file
    :return: filename without extension
    """
    base_name = os.path.basename(file_name)
    stem, _extension = os.path.splitext(base_name)
    return stem
def get_files(path_list, extension):
    """
    List files matching a glob pattern, in natural sort order.

    :param path_list: a directory path, or a list of path components joined with os.path.join
    :param extension: glob pattern (e.g. '*.png') appended to the joined path
    :return: naturally sorted list of matching file paths
    """
    directories = [path_list] if isinstance(path_list, str) else path_list
    pattern = os.path.join(*directories, extension)
    return natsorted(glob(pattern))
def recursive_update(d, u):
    """
    Recursively merge the nested mapping ``u`` into ``d`` (in place).

    Nested mappings are merged key by key; every other value from ``u`` overwrites the
    one in ``d``. If ``u`` holds a mapping for a key whose current value in ``d`` is not
    a mapping (e.g. a scalar), that value is replaced by the merged mapping instead of
    failing.

    :param d: the dictionary to be updated (modified in place)
    :param u: the new dictionary
    :return: ``d``
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            current = d.get(k)
            if not isinstance(current, collections.abc.Mapping):
                # Missing key or a non-mapping value: start a fresh dict to merge into.
                current = {}
            d[k] = recursive_update(current, v)
        else:
            d[k] = v
    return d
def parse_args(arg_list):
    """
    Parse command-line style tokens into a nested dictionary.

    Consecutive ``--key`` tokens open nested levels, and the first value token closes
    them. For example, ['--trainer', '--lr', '0.1', '--gpu', '0'] becomes
    {'trainer': {'lr': 0.1}, 'gpu': 0.0}. Values are converted to float when possible
    and kept as strings otherwise. A ``--key`` at the very end with no value maps to
    None. Entries are merged into the result with recursive_update.

    Note: ``arg_list`` is consumed (emptied) while parsing.

    :param arg_list: the arguments in a list where each element is either --key or val
    :return: dictionary of arguments
    :raises ValueError: if a value token is not preceded by a --key
    """
    def parse_args_helper(arg_l):
        """
        Pop one token from ``arg_l``. Return {key: <recursively parsed rest>} for a
        --key, the converted value otherwise, or None if ``arg_l`` is empty.
        """
        if not arg_l:
            return None
        item = arg_l.pop(0)
        # Only a leading '--' marks a key; a value such as 'run--1' stays a value.
        if item.startswith('--'):
            return {item[2:]: parse_args_helper(arg_l)}
        try:
            return float(item)
        except ValueError:
            return item

    arg_dict = {}
    while arg_list:
        item = parse_args_helper(arg_list)
        if not isinstance(item, dict):
            raise ValueError('Argument value {!r} is not preceded by a --key'.format(item))
        recursive_update(arg_dict, item)
    return arg_dict
def update_flags(flags, cf_dict):
    """
    Merge command-line overrides into the config, then apply the backward-compatibility fixes.

    :param flags: dictionary of configurations, this is from the config.json file (modified in place)
    :param cf_dict: dictionary of configurations, this is from command line
    :return: the result of historical_update_flag on the merged flags
    """
    merged = recursive_update(flags, cf_dict)
    return historical_update_flag(merged, cf_dict)
def historical_update_flag(flags, cf_dict):
    """
    Copy cf_dict['config'] into flags['config'], then run historical_process_flag.

    :param flags: dictionary of configurations, this is from the config.json file (modified in place)
    :param cf_dict: dictionary of configurations, this is from command line; must contain 'config'
    :return: the result of historical_process_flag(flags)
    """
    config_value = cf_dict['config']
    flags['config'] = config_value
    return historical_process_flag(flags)
def historical_process_flag(flags):
    """
    This function updates flag to make it backward compatible with old versions.

    It fills in defaults for keys that older config files lack and normalizes legacy
    value formats. ``flags`` is modified in place and also returned.

    Keys indexed directly (a missing one raises KeyError): flags['optimizer'],
    flags['trainer'] and flags['trainer']['bp_loss_idx']. When 'class_weight' is
    missing, flags['dataset']['class_num'] is also required.

    :param flags: dictionary of configurations, this is from the config.json file
    :return: the same ``flags`` dictionary after the updates
    """
    # Note: the default is the string 'True', not the bool.
    if 'imagenet' not in flags:
        flags['imagenet'] = 'True'
    if 'name' not in flags['optimizer']:
        flags['optimizer']['name'] = 'sgd'
    if 'aux_loss' not in flags['optimizer']:
        flags['optimizer']['aux_loss'] = 0
    # After the check above 'aux_loss' is always present, so this branch always runs.
    if 'aux_loss' in flags['optimizer']:
        if 'aux_loss_weight' not in flags['optimizer']:
            flags['optimizer']['aux_loss_weight'] = 0.4
    # Default class weights: the string '(1,1,...,1)' with one entry per class.
    if 'class_weight' not in flags['trainer']:
        flags['trainer']['class_weight'] = '({})'.format(','.join(['1' for _ in range(flags['dataset']['class_num'])]))
    if 'loss_weights' not in flags['trainer']:
        flags['trainer']['loss_weights'] = 'None'
    # A single-character string index (e.g. '0') or an int index becomes the
    # tuple-literal string '(0,)'; longer strings are left untouched.
    if isinstance(flags['trainer']['bp_loss_idx'], str) and len(flags['trainer']['bp_loss_idx']) == 1:
        flags['trainer']['bp_loss_idx'] = '({},)'.format(flags['trainer']['bp_loss_idx'])
    if isinstance(flags['trainer']['bp_loss_idx'], int):
        flags['trainer']['bp_loss_idx'] = '({},)'.format(flags['trainer']['bp_loss_idx'])
    # An int loss weight becomes an actual 1-tuple (not a string, unlike bp_loss_idx).
    if isinstance(flags['trainer']['loss_weights'], int):
        flags['trainer']['loss_weights'] = (flags['trainer']['loss_weights'],)
    # SECURITY NOTE(review): the eval() calls below execute arbitrary expressions taken
    # from the config file. Only use trusted configs; ast.literal_eval would be a safer
    # way to turn 'True'/'False' strings into bools.
    if 'further_train' not in flags['trainer']:
        flags['trainer']['further_train'] = False
    elif isinstance(flags['trainer']['further_train'], str):
        flags['trainer']['further_train'] = eval(flags['trainer']['further_train'])
    if 'use_emau' not in flags:
        flags['use_emau'] = False
    elif isinstance(flags['use_emau'], str):
        flags['use_emau'] = eval(flags['use_emau'])
    if 'use_ocr' not in flags:
        flags['use_ocr'] = False
    else:
        # eval() raises TypeError for non-string values (e.g. an existing bool), which
        # are then kept as is. Other errors from eval (NameError, SyntaxError) propagate.
        try:
            flags['use_ocr'] = eval(flags['use_ocr'])
        except TypeError:
            pass
    # All top-level keys containing 'dataset', sorted lexicographically. The first one
    # must be exactly 'dataset'. NOTE(review): a key such as 'aux_dataset' would sort
    # before 'dataset' and fail this check; if no such key exists, an IndexError is
    # raised. The assert is also skipped under `python -O`.
    flags['ds_cfgs'] = [a for a in sorted(flags.keys()) if 'dataset' in a]
    assert flags['ds_cfgs'][0] == 'dataset'
    if 'gamma' not in flags['trainer']:
        flags['trainer']['gamma'] = 2
    if 'alpha' not in flags['trainer']:
        flags['trainer']['alpha'] = 0.25
    # Every dataset section gets load_func='default' unless it already specifies
    # 'default' or 'None' (any other value fails the assert).
    for ds_flag in flags['ds_cfgs']:
        if 'load_func' not in flags[ds_flag]:
            flags[ds_flag]['load_func'] = 'default'
        else:
            assert flags[ds_flag]['load_func'] == 'default' or flags[ds_flag]['load_func'] == 'None'
    return flags