import math
import os
import sys
from enum import Enum

import torch
import torch.nn.functional as F


class TestAccuracies:
    """
    Determines if an evaluation on the validation set is better than the best so far.
    In particular, this handles the case for Meta-Dataset, where we validate on multiple datasets and deem
    the evaluation to be better if more than half of the validation accuracies on the individual validation
    datasets improve on the previous best.
    """
    def __init__(self, validation_datasets):
        self.datasets = validation_datasets
        self.dataset_count = len(self.datasets)
        # self.current_best_accuracy_dict = {}
        # for dataset in self.datasets:
        #     self.current_best_accuracy_dict[dataset] = {"accuracy": 0.0, "confidence": 0.0}

    # def is_better(self, accuracies_dict):
    #     is_better = False
    #     is_better_count = 0
    #     for i, dataset in enumerate(self.datasets):
    #         if accuracies_dict[dataset]["accuracy"] > self.current_best_accuracy_dict[dataset]["accuracy"]:
    #             is_better_count += 1
    #
    #     if is_better_count >= int(math.ceil(self.dataset_count / 2.0)):
    #         is_better = True
    #
    #     return is_better

    # def replace(self, accuracies_dict):
    #     self.current_best_accuracy_dict = accuracies_dict

    def print(self, logfile, accuracy_dict):
        print_and_log(logfile, "")  # add a blank line
        print_and_log(logfile, "Test Accuracies:")
        for dataset in self.datasets:
            print_and_log(logfile, "{0:}: {1:.1f}+/-{2:.1f}".format(dataset, accuracy_dict[dataset]["accuracy"],
                                                                    accuracy_dict[dataset]["confidence"]))
        print_and_log(logfile, "")  # add a blank line

    # def get_current_best_accuracy_dict(self):
    #     return self.current_best_accuracy_dict


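# A minimal usage sketch for TestAccuracies.print (dataset names and accuracy values below are
# hypothetical, not from this repo; logfile is the handle returned by get_log_files):
#   accuracies = TestAccuracies(["ilsvrc_2012", "omniglot"])
#   accuracy_dict = {"ilsvrc_2012": {"accuracy": 52.3, "confidence": 1.1},
#                    "omniglot": {"accuracy": 88.7, "confidence": 0.6}}
#   accuracies.print(logfile, accuracy_dict)

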
def verify_checkpoint_dir(checkpoint_dir, resume, test_mode):
    """
    Check that the checkpoint directory is in a valid state for the requested mode before any files are written.
    """
    if resume:  # verify that the checkpoint directory and file exist
        if not os.path.exists(checkpoint_dir):
            print("Can't resume from checkpoint. Checkpoint directory ({}) does not exist.".format(checkpoint_dir),
                  flush=True)
            sys.exit()

        checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint.pt')
        if not os.path.isfile(checkpoint_file):
            print("Can't resume from checkpoint. Checkpoint file ({}) does not exist.".format(checkpoint_file),
                  flush=True)
            sys.exit()
    # elif test_mode:
    #     if not os.path.exists(checkpoint_dir):
    #         print("Can't test. Checkpoint directory ({}) does not exist.".format(checkpoint_dir), flush=True)
    #         sys.exit()
    else:
        if os.path.exists(checkpoint_dir):
            print("Checkpoint directory ({}) already exists.".format(checkpoint_dir), flush=True)
            print("If starting a new training run, specify a directory that does not already exist.", flush=True)
            print("If you want to resume a training run, specify the -r option on the command line.", flush=True)
            sys.exit()


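# Behaviour sketch: with resume=True the directory and its checkpoint.pt must already exist;
# with resume=False the directory must not exist yet (it is created later by get_log_files).
# The path below is hypothetical:
#   verify_checkpoint_dir('./checkpoints/run1', resume=False, test_mode=False)

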
def print_and_log(log_file, message):
    """
    Helper function to print a message to the screen and append it to the log file.
    """
    print(message, flush=True)
    log_file.write(message + '\n')


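# Usage sketch (log_file is an open, line-buffered file object such as the one returned by get_log_files):
#   print_and_log(logfile, "Starting training run...")

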
def get_log_files(checkpoint_dir, resume, test_mode):
    """
    Function that takes a path to a checkpoint directory and returns a reference to a logfile and paths to the
    fully trained model and the model with the best validation score.
    """
    verify_checkpoint_dir(checkpoint_dir, resume, test_mode)
    # if not test_mode and not resume:
    if not resume:
        os.makedirs(checkpoint_dir)

    checkpoint_path_validation = os.path.join(checkpoint_dir, 'best_validation.pt')
    checkpoint_path_final = os.path.join(checkpoint_dir, 'fully_trained.pt')
    logfile_path = os.path.join(checkpoint_dir, 'log.txt')
    if os.path.isfile(logfile_path):
        logfile = open(logfile_path, "a", buffering=1)
    else:
        logfile = open(logfile_path, "w", buffering=1)

    return checkpoint_dir, logfile, checkpoint_path_validation, checkpoint_path_final


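# Usage sketch (directory name is hypothetical); four values come back: the directory itself,
# an open log file handle, and the two model checkpoint paths:
#   checkpoint_dir, logfile, best_val_path, final_path = get_log_files('./checkpoints/run1',
#                                                                      resume=False, test_mode=False)

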
def stack_first_dim(x):
    """
    Method to combine the first two dimensions of a tensor.
    """
    x_shape = x.size()
    new_shape = [x_shape[0] * x_shape[1]]
    if len(x_shape) > 2:
        new_shape += x_shape[2:]
    return x.view(new_shape)


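# Shape sketch (sizes are hypothetical): a [num_samples, batch, d] tensor is flattened to
# [num_samples * batch, d] so it can be pushed through layers that expect 2-D input:
#   x = torch.randn(10, 5, 64)
#   stacked = stack_first_dim(x)   # shape: [50, 64]

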
def split_first_dim_linear(x, first_two_dims):
    """
    Undo the stacking operation.
    """
    x_shape = x.size()
    new_shape = list(first_two_dims)  # copy so the caller's list is not modified in place
    if len(x_shape) > 1:
        new_shape += [x_shape[-1]]
    return x.view(new_shape)


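# Continuing the sketch above, the stacked tensor is split back into its original first two dimensions:
#   restored = split_first_dim_linear(stacked, [10, 5])   # shape: [10, 5, 64]

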
def sample_normal(mean, var, num_samples):
    """
    Generate samples from a reparameterized normal distribution.
    :param mean: tensor - mean parameter of the distribution
    :param var: tensor - variance of the distribution
    :param num_samples: np scalar - number of samples to generate
    :return: tensor - samples from distribution of size numSamples x dim(mean)
    """
    sample_shape = [num_samples] + len(mean.size()) * [1]
    # Note: torch.distributions.Normal takes the scale (standard deviation) as its second argument.
    normal_distribution = torch.distributions.Normal(mean.repeat(sample_shape), var.repeat(sample_shape))
    return normal_distribution.rsample()


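# Usage sketch (hypothetical shapes): drawing 10 reparameterized samples from a 3-dimensional Gaussian,
# so gradients can flow back to mean and var through rsample():
#   mean = torch.zeros(3)
#   var = torch.ones(3)
#   samples = sample_normal(mean, var, num_samples=10)   # shape: [10, 3]

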
def loss(test_logits_sample, test_labels, device):
    """
    Compute the classification loss by averaging the per-sample likelihoods in log space (log-sum-exp).
    """
    size = test_logits_sample.size()
    sample_count = size[0]  # scalar for the loop counter
    num_samples = torch.tensor([sample_count], dtype=torch.float, device=device, requires_grad=False)

    log_py = torch.empty(size=(size[0], size[1]), dtype=torch.float, device=device)
    for sample in range(sample_count):
        log_py[sample] = -F.cross_entropy(test_logits_sample[sample], test_labels, reduction='none')
    score = torch.logsumexp(log_py, dim=0) - torch.log(num_samples)
    return -torch.sum(score, dim=0)


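# Shape sketch (sizes are hypothetical): test_logits_sample is [num_samples, num_target_points, num_classes]
# and test_labels is [num_target_points]; the result is a scalar negative log-likelihood:
#   logits = torch.randn(5, 20, 4)
#   labels = torch.randint(0, 4, (20,))
#   nll = loss(logits, labels, device=torch.device('cpu'))

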
def aggregate_accuracy(test_logits_sample, test_labels):
    """
    Compute classification accuracy.
    """
    averaged_predictions = torch.logsumexp(test_logits_sample, dim=0)
    return torch.mean(torch.eq(test_labels, torch.argmax(averaged_predictions, dim=-1)).float())


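# Usage sketch, reusing the hypothetical logits and labels from the loss() sketch above; predictions are
# aggregated over the sample dimension before being compared against the labels:
#   acc = aggregate_accuracy(logits, labels)   # fraction of correct predictions in [0, 1]

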
def task_confusion(test_logits, test_labels, real_test_labels, batch_class_list):
    """
    Map the within-task predicted class indices back to the real class labels via batch_class_list.
    """
    preds = torch.argmax(torch.logsumexp(test_logits, dim=0), dim=-1)
    real_preds = batch_class_list[preds]
    return real_preds


def linear_classifier(x, param_dict):
    """
    Classifier: apply a linear layer using the mean weight and bias stored in param_dict.
    """
    return F.linear(x, param_dict['weight_mean'], param_dict['bias_mean'])


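# Usage sketch (hypothetical shapes): a [batch, d] feature tensor classified with mean weights of
# shape [num_classes, d] and a mean bias of shape [num_classes]:
#   params = {'weight_mean': torch.randn(4, 64), 'bias_mean': torch.zeros(4)}
#   logits = linear_classifier(torch.randn(20, 64), params)   # shape: [20, 4]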