-
Notifications
You must be signed in to change notification settings - Fork 34
/
softmax.py
241 lines (176 loc) · 8.43 KB
/
softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Angus Dempster, Daniel F Schmidt, Geoffrey I Webb
# MiniRocket: A Very Fast (Almost) Deterministic Transform for Time Series
# Classification
# https://arxiv.org/abs/2012.08791
import copy
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from minirocket import fit, transform
def train(path, num_classes, training_size, **kwargs):
    """Train a softmax (linear) classifier on streamed MiniRocket features.

    The CSV at ``path`` is read without a header. The first
    ``validation_size`` rows are held out as validation data; the following
    ``training_size`` rows are streamed in chunks of ``chunk_size`` rows as
    training data. In every row, column 0 is the class label and the
    remaining columns are the time series values.

    MiniRocket ``parameters`` and the per-feature normalisation statistics
    are computed from the first training chunk only. Transformed (and
    normalised) chunks are cached in memory for the first ``cache_size``
    training rows, so later epochs can skip re-reading and re-transforming
    those rows.

    Args:
        path: path to the headerless CSV file.
        num_classes: number of output classes. Labels are used directly as
            ``nn.CrossEntropyLoss`` targets, so they must be integer class
            indices in ``[0, num_classes)``.
        training_size: number of training rows following the validation rows.
        **kwargs: overrides for the default hyperparameters defined below
            (``num_features``, ``validation_size``, ``chunk_size``,
            ``minibatch_size``, ``lr``, ``max_epochs``, ``patience_lr``,
            ``patience``, ``cache_size``, ``validation_interval``).

    Returns:
        ``(parameters, best_model, f_mean, f_std)``: the MiniRocket parameters
        returned by ``fit``, a copy of the model with the lowest validation
        loss seen (or the final model if no validation check ever ran), and
        the per-feature mean and standard deviation used for normalisation.

    Raises:
        ValueError: if ``training_size`` or ``max_epochs`` is not positive
            (nothing would be fit, so there would be nothing to return).
    """

    # -- init ------------------------------------------------------------------

    # default hyperparameters are reusable for any dataset
    args = {
        "num_features"        : 10_000,
        "validation_size"     : 2 ** 11,
        "chunk_size"          : 2 ** 12,
        "minibatch_size"      : 256,
        "lr"                  : 1e-4,
        "max_epochs"          : 50,
        "patience_lr"         : 5,   # validations (= 50 minibatches at the default interval)
        "patience"            : 10,  # validations (= 100 minibatches at the default interval)
        "cache_size"          : training_size,  # set to 0 to prevent caching
        "validation_interval" : 10,  # minibatches between validation checks
    }
    args = {**args, **kwargs}

    if training_size <= 0:
        raise ValueError(f"training_size must be positive, got {training_size}")
    if args["max_epochs"] <= 0:
        raise ValueError(f"max_epochs must be positive, got {args['max_epochs']}")

    # width of the transformed feature vector: num_features rounded down to a
    # multiple of 84 (assumed to match what minirocket.fit/transform produce --
    # confirm if that module changes)
    _num_features = 84 * (args["num_features"] // 84)
    num_chunks = int(np.ceil(training_size / args["chunk_size"]))

    def init(layer):
        # start the classifier from all-zero weights and bias
        if isinstance(layer, nn.Linear):
            nn.init.constant_(layer.weight.data, 0)
            nn.init.constant_(layer.bias.data, 0)

    # -- cache -----------------------------------------------------------------

    # cache as much as possible to avoid unnecessarily repeating the transform
    # consider caching to disk if appropriate, along the lines of numpy.memmap
    cache_X = torch.zeros((args["cache_size"], _num_features))
    cache_Y = torch.zeros(args["cache_size"], dtype = torch.long)
    cache_count = 0       # training rows [0, cache_count) are cached
    fully_cached = False  # True once every training row is cached

    # -- model -----------------------------------------------------------------

    model = nn.Sequential(nn.Linear(_num_features, num_classes))
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = args["lr"])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, min_lr = 1e-8, patience = args["patience_lr"])
    model.apply(init)

    # -- validation data -------------------------------------------------------

    # gotcha: copy() is essential to avoid competition for memory access with read_csv(...)
    validation_data = pd.read_csv(path,
                                  header = None,
                                  sep = ",",
                                  nrows = args["validation_size"],
                                  engine = "c").values.copy()
    # column 0 holds the labels, the remaining columns the series values
    Y_validation, X_validation = torch.LongTensor(validation_data[:, 0]), validation_data[:, 1:].astype(np.float32)

    # -- run -------------------------------------------------------------------

    minibatch_count = 0
    best_validation_loss = np.inf
    best_model = None
    stall_count = 0
    stop = False
    file = None  # chunked CSV reader over the training rows

    print("Training... (faster once caching is finished)")

    try:
        for epoch in range(args["max_epochs"]):

            # the `epoch > 0` guards (here and below) mean early stopping only
            # takes effect from the second epoch: the first epoch always runs
            # over every training chunk
            if epoch > 0 and stop:
                break

            # release the previous epoch's reader before (possibly) reopening
            if file is not None:
                file.close()
                file = None

            if not fully_cached:
                file = pd.read_csv(path,
                                   header = None,
                                   sep = ",",
                                   skiprows = args["validation_size"],
                                   chunksize = args["chunk_size"],
                                   engine = "c")

            for chunk_index in range(num_chunks):

                # this chunk covers training rows [a, b); _b is its row count
                a = chunk_index * args["chunk_size"]
                b = min(a + args["chunk_size"], training_size)
                _b = b - a

                if epoch > 0 and stop:
                    break

                print(f"Epoch {epoch + 1}; Chunk = {chunk_index + 1}...".ljust(80, " "), end = "\r", flush = True)

                # if not fully cached, read next file chunk (this also keeps the
                # reader in step for chunks that are then served from the cache)
                if not fully_cached:

                    # gotcha: copy() is essential to avoid competition for memory access with read_csv(...)
                    training_data = file.get_chunk().values[:_b].copy()
                    Y_training, X_training = torch.LongTensor(training_data[:, 0]), training_data[:, 1:].astype(np.float32)

                    if epoch == 0 and chunk_index == 0:

                        # MiniRocket parameters are fit on the first training chunk only
                        parameters = fit(X_training, args["num_features"])

                        # transform validation data
                        X_validation_transform = transform(X_validation, parameters)

                # if cached, retrieve from cache
                if b <= cache_count:

                    X_training_transform = cache_X[a:b]
                    Y_training = cache_Y[a:b]

                # else, transform and cache
                else:

                    # transform training data
                    X_training_transform = transform(X_training, parameters)

                    if epoch == 0 and chunk_index == 0:

                        # per-feature mean and standard deviation (first chunk only);
                        # the 1e-8 keeps the division below finite for constant features
                        f_mean = X_training_transform.mean(0)
                        f_std = X_training_transform.std(0) + 1e-8

                        # normalise validation features
                        X_validation_transform = (X_validation_transform - f_mean) / f_std
                        X_validation_transform = torch.FloatTensor(X_validation_transform)

                    # normalise training features
                    X_training_transform = (X_training_transform - f_mean) / f_std
                    X_training_transform = torch.FloatTensor(X_training_transform)

                    # cache as much of the transform as possible
                    if b <= args["cache_size"]:
                        cache_X[a:b] = X_training_transform
                        cache_Y[a:b] = Y_training
                        cache_count = b

                    if cache_count >= training_size:
                        fully_cached = True

                minibatches = torch.randperm(len(X_training_transform)).split(args["minibatch_size"])

                # train on transformed features
                for minibatch_index, minibatch in enumerate(minibatches):

                    if epoch > 0 and stop:
                        break

                    # drop a trailing partial minibatch (unless it is the only one)
                    if minibatch_index > 0 and len(minibatch) < args["minibatch_size"]:
                        break

                    # -- training ----------------------------------------------

                    optimizer.zero_grad()
                    _Y_training = model(X_training_transform[minibatch])
                    training_loss = loss_function(_Y_training, Y_training[minibatch])
                    training_loss.backward()
                    optimizer.step()

                    minibatch_count += 1

                    # -- validation --------------------------------------------

                    if minibatch_count % args["validation_interval"] == 0:

                        # evaluation only: don't record an autograd graph
                        with torch.no_grad():
                            _Y_validation = model(X_validation_transform)
                            validation_loss = loss_function(_Y_validation, Y_validation).item()

                        scheduler.step(validation_loss)

                        if validation_loss >= best_validation_loss:
                            stall_count += 1
                            # report only when the stopping criterion is first met
                            if stall_count >= args["patience"] and not stop:
                                stop = True
                                print(f"\n<Stopped at Epoch {epoch + 1}>")
                        else:
                            best_validation_loss = validation_loss
                            best_model = copy.deepcopy(model)
                            # after stopping (training still continues through
                            # epoch 0), an improvement is kept as the best model
                            # but does not reset the stall count
                            if not stop:
                                stall_count = 0
    finally:
        if file is not None:
            file.close()

    # no validation check ever ran (fewer than `validation_interval`
    # minibatches in total), so fall back to the final model
    if best_model is None:
        best_model = model

    return parameters, best_model, f_mean, f_std
def predict(path,
            parameters,
            model,
            f_mean,
            f_std,
            **kwargs):
    """Predict class labels for a headerless CSV using the outputs of ``train``.

    The file is streamed in chunks of ``chunk_size`` rows. In each row,
    column 0 is the true label and the remaining columns are the series
    values. Features are transformed with ``parameters`` and normalised with
    ``f_mean``/``f_std`` before being passed to ``model``.

    Args:
        path: path to the headerless CSV file.
        parameters: MiniRocket parameters returned by ``train``.
        model: trained classifier returned by ``train``.
        f_mean: per-feature mean returned by ``train``.
        f_std: per-feature standard deviation returned by ``train``.
        **kwargs: overrides for ``score`` (default True), ``chunk_size``
            (default 2 ** 12) and ``test_size`` (number of rows to read;
            the default None reads the whole file).

    Returns:
        A 1-D array of predicted class indices. If ``score`` is True, a tuple
        ``(predictions, accuracy)`` is returned instead, where ``accuracy`` is
        the fraction of rows whose prediction equals column 0 (NaN if no rows
        were read).
    """

    args = {
        "score"      : True,
        "chunk_size" : 2 ** 12,
        "test_size"  : None,
    }
    args = {**args, **kwargs}

    file = pd.read_csv(path,
                       header = None,
                       sep = ",",
                       chunksize = args["chunk_size"],
                       nrows = args["test_size"],
                       engine = "c")

    predictions = []
    correct = 0
    total = 0

    try:
        # inference only: no autograd graph is needed
        with torch.no_grad():
            for chunk_index, chunk in enumerate(file):

                print(f"Chunk = {chunk_index + 1}...".ljust(80, " "), end = "\r", flush = True)

                # gotcha: copy() is essential to avoid competition for memory access with read_csv(...)
                test_data = chunk.values.copy()
                Y_test, X_test = test_data[:, 0], test_data[:, 1:].astype(np.float32)

                # transform, then normalise with the training-set statistics
                X_test_transform = transform(X_test, parameters)
                X_test_transform = (X_test_transform - f_mean) / f_std
                X_test_transform = torch.FloatTensor(X_test_transform)

                _predictions = model(X_test_transform).argmax(1).numpy()
                predictions.append(_predictions)

                total += len(test_data)
                correct += (_predictions == Y_test).sum()
    finally:
        # release the file handle even if a chunk fails part-way
        file.close()

    # np.concatenate rejects an empty list, e.g. when test_size=0
    if predictions:
        predictions = np.concatenate(predictions)
    else:
        predictions = np.empty(0, dtype = np.int64)

    if args["score"]:
        accuracy = correct / total if total else np.nan
        return predictions, accuracy
    else:
        return predictions