-
Notifications
You must be signed in to change notification settings - Fork 43
/
utils.py
105 lines (90 loc) · 3.97 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import logging
import json
import numpy as np
import torch
def sort_dataset(data, labels, num_classes=10, stack=False):
    """Sort dataset based on classes.

    Samples are grouped by class label; within each class the original
    sample order is preserved.

    Parameters:
        data (np.ndarray): data array whose first axis indexes samples
        labels (np.ndarray): one dimensional array of integer class labels,
            each in ``[0, num_classes)``
        num_classes (int): number of classes
        stack (bool): combine sorted data into one numpy array

    Return:
        sorted_data, sorted_labels. When ``stack`` is False these are lists of
        ``num_classes`` arrays, one per class (a class with no samples yields
        an empty array). When ``stack`` is True they are single arrays
        concatenated along the sample axis.

    Raises:
        ValueError: if any label lies outside ``[0, num_classes)``.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if labels.size:
        lo, hi = labels.min(), labels.max()
        # A negative label would otherwise be silently misfiled, and a label
        # >= num_classes would otherwise be silently dropped.
        if lo < 0 or hi >= num_classes:
            raise ValueError(
                "labels must lie in [0, {}), got range [{}, {}]".format(
                    num_classes, lo, hi))
    # Boolean masks keep the original sample order and produce a correctly
    # shaped empty array for classes without samples (np.stack([]) would fail).
    sorted_data = [data[labels == c] for c in range(num_classes)]
    sorted_labels = [np.repeat(c, len(class_data))
                     for c, class_data in enumerate(sorted_data)]
    if stack:
        # Concatenate along the sample axis so rows stay aligned with the
        # concatenated labels (np.vstack would turn 1-D inputs into a 2-D grid).
        sorted_data = np.concatenate(sorted_data, axis=0)
        sorted_labels = np.hstack(sorted_labels)
    return sorted_data, sorted_labels
def init_pipeline(model_dir, headers=None):
    """Create the project directory tree and an initial losses CSV.

    Creates ``model_dir`` plus its ``checkpoints``, ``figures`` and
    ``plabels`` subdirectories, then writes ``losses.csv`` via ``create_csv``.
    ``os.makedirs`` is called without ``exist_ok``, so a ``FileExistsError``
    is raised if ``model_dir`` already exists.

    Parameters:
        model_dir (str): path of the project folder to create
        headers (list): CSV column names; a default list of loss columns is
            used when None
    """
    os.makedirs(model_dir)
    for subdir in ('checkpoints', 'figures', 'plabels'):
        os.makedirs(os.path.join(model_dir, subdir))
    columns = headers
    if columns is None:
        columns = ["epoch", "step", "loss", "discrimn_loss_e", "compress_loss_e",
                   "discrimn_loss_t", "compress_loss_t"]
    create_csv(model_dir, 'losses.csv', columns)
    print("project dir: {}".format(model_dir))
def create_csv(model_dir, filename, headers):
    """Create ``model_dir/filename`` containing only a header row.

    Any existing file at that path is deleted first. The header is written as
    the comma-joined ``str()`` of each entry, with no trailing newline.

    Parameters:
        model_dir (str): directory that will hold the file
        filename (str): name of the CSV file
        headers (iterable): column names for the first line

    Return:
        str: path of the created CSV file
    """
    csv_path = os.path.join(model_dir, filename)
    if os.path.exists(csv_path):
        os.remove(csv_path)
    with open(csv_path, 'w+') as handle:
        header_line = ','.join(str(column) for column in headers)
        handle.write(header_line)
    return csv_path
def save_params(model_dir, params):
    """Write the ``params`` dictionary to ``model_dir/params.json``.

    The JSON is indented by two spaces with sorted keys; an existing file is
    overwritten.
    """
    params_file = os.path.join(model_dir, 'params.json')
    with open(params_file, 'w') as fp:
        json.dump(params, fp, sort_keys=True, indent=2)
def update_params(model_dir, pretrain_dir):
    """Copy architecture and feature dimension from a pretrain directory.

    Loads ``params.json`` from both directories and overwrites the ``'arch'``
    and ``'fd'`` entries of ``model_dir``'s params with the pretrain values.
    The result is written back to ``model_dir``. A ``KeyError`` is raised if
    the pretrain params lack either key.
    """
    params = load_params(model_dir)
    pretrained = load_params(pretrain_dir)
    for key in ('arch', 'fd'):
        params[key] = pretrained[key]
    save_params(model_dir, params)
def load_params(model_dir):
    """Read ``model_dir/params.json`` and return the parsed dictionary."""
    with open(os.path.join(model_dir, "params.json"), 'r') as fp:
        return json.load(fp)
def save_state(model_dir, *entries, filename='losses.csv'):
    """Append one row of values to an existing CSV in ``model_dir``.

    The row is written as a newline followed by the comma-joined ``str()`` of
    each entry. This matches ``create_csv``, which writes its header without a
    trailing newline.

    Parameters:
        model_dir (str): project directory containing the CSV
        *entries: values for the row (typically numbers)
        filename (str): name of the CSV file inside ``model_dir``

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    csv_path = os.path.join(model_dir, filename)
    # An explicit check (not ``assert``) still fires under ``python -O``;
    # otherwise append mode would silently create a header-less file.
    if not os.path.exists(csv_path):
        raise FileNotFoundError(
            'CSV file is missing in project directory: {}'.format(csv_path))
    with open(csv_path, 'a') as f:
        f.write('\n' + ','.join(map(str, entries)))
def save_ckpt(model_dir, net, epoch):
    """Save ``net.state_dict()`` to ``model_dir/checkpoints/model-epoch{epoch}.pt``."""
    ckpt_path = os.path.join(model_dir, 'checkpoints', f'model-epoch{epoch}.pt')
    torch.save(net.state_dict(), ckpt_path)
def save_labels(model_dir, labels, epoch):
    """Save ``labels`` to ``model_dir/plabels/epoch{epoch}.npy`` with ``np.save``."""
    target = os.path.join(model_dir, 'plabels', 'epoch{}.npy'.format(epoch))
    np.save(target, labels)
def compute_accuracy(y_pred, y_true):
    """Compute accuracy by counting correct classification.

    Parameters:
        y_pred (array-like): predicted labels
        y_true (array-like): ground-truth labels, same shape as ``y_pred``

    Return:
        float: fraction of positions where ``y_pred`` equals ``y_true``

    Raises:
        ValueError: if the shapes differ or the inputs are empty.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.shape != y_true.shape:
        raise ValueError("shape mismatch: y_pred {} vs y_true {}".format(
            y_pred.shape, y_true.shape))
    if y_true.size == 0:
        raise ValueError("cannot compute accuracy of empty arrays")
    # Compare with != instead of subtracting: numpy does not support `-` on
    # boolean arrays, and subtraction is undefined for string labels.
    return 1 - np.count_nonzero(y_pred != y_true) / y_true.size
def clustering_accuracy(labels_true, labels_pred):
    """Compute clustering accuracy under the best one-to-one label matching.

    Builds the contingency matrix between ground-truth classes and predicted
    clusters. It then uses the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``) to find the class-to-cluster
    assignment with the most agreements. The result is that agreement count
    divided by the number of samples.

    Parameters:
        labels_true (array-like): 1-D ground-truth class labels
        labels_pred (array-like): 1-D predicted cluster labels, same length

    Return:
        float: matched accuracy in ``[0, 1]``

    Raises:
        ValueError: if the inputs are not 1-D, differ in length, or are empty.
    """
    from scipy.optimize import linear_sum_assignment
    labels_true = np.asarray(labels_true)
    labels_pred = np.asarray(labels_pred)
    if labels_true.ndim != 1 or labels_pred.ndim != 1:
        raise ValueError("labels_true and labels_pred must be 1-D, got shapes "
                         "{} and {}".format(labels_true.shape, labels_pred.shape))
    if labels_true.shape != labels_pred.shape:
        raise ValueError("labels_true and labels_pred must have the same length, "
                         "got {} and {}".format(labels_true.size, labels_pred.size))
    if labels_true.size == 0:
        raise ValueError("cannot compute clustering accuracy of empty labels")
    # Contingency matrix: rows are sorted unique true classes, columns are
    # sorted unique predicted clusters. This is built with numpy because
    # sklearn's `supervised` module used previously no longer exists.
    classes, true_idx = np.unique(labels_true, return_inverse=True)
    clusters, pred_idx = np.unique(labels_pred, return_inverse=True)
    contingency = np.zeros((classes.size, clusters.size), dtype=np.int64)
    np.add.at(contingency, (true_idx.ravel(), pred_idx.ravel()), 1)
    # Negate so the minimum-cost assignment maximizes matched counts.
    rows, cols = linear_sum_assignment(-contingency)
    return contingency[rows, cols].sum() / labels_true.size