-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
77 lines (65 loc) · 2.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import tqdm
from log import Logger
from sys import argv
# Reproducibility and command-line flags.
torch.manual_seed(2)
grok = "--grok" in argv
log = "--log" in argv
logger = Logger("grokking" if grok else "comprehension") if log else None

# Hyperparameters.
num_epochs = 10000
learning_rate = 3e-2
# Light decay (3e-2) reproduces the delayed-generalization "grokking" regime;
# the heavy value (5) gives prompt comprehension.
weight_decay = 3e-2 if grok else 5

# Data: every ordered pair (a, b) with a, b in [0, P), labeled (a + b) mod P.
P = 53
train_frac = 0.6
X = torch.cartesian_prod(torch.arange(P), torch.arange(P))
y = X.sum(dim=1) % P
perm = torch.randperm(len(X))
X, y = X[perm], y[perm]
split = int(train_frac * len(X))
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]
# Model
class Model(nn.Module):
    """One-hidden-layer MLP over concatenated embeddings of two input tokens.

    Each of the two tokens in an input pair is embedded, the two embeddings
    are concatenated, passed through a ReLU hidden layer, and read out as
    logits over the ``num_tokens`` classes.
    """

    def __init__(self, hidden_dim=256, num_tokens=None):
        """
        Args:
            hidden_dim: width of the embedding and of the hidden layer.
            num_tokens: vocabulary / output size. Defaults to the
                module-level modulus ``P`` for backward compatibility.
        """
        super().__init__()
        if num_tokens is None:
            num_tokens = P  # module-level modulus (53 in this script)
        self.embedding = nn.Embedding(num_tokens, hidden_dim)
        # Input to fc1 is the concatenation of the two token embeddings.
        self.fc1 = nn.Linear(2 * hidden_dim, hidden_dim)
        self.readout = nn.Linear(hidden_dim, num_tokens)

    def forward(self, x):
        """Map a (batch, 2) integer tensor of token pairs to (batch, num_tokens) logits."""
        x = self.embedding(x).flatten(start_dim=1)  # (batch, 2 * hidden_dim)
        x = torch.relu(self.fc1(x))
        return self.readout(x)
# Instantiate the network and its training machinery.
model = Model(hidden_dim=128)
criterion = nn.CrossEntropyLoss()  # raw logits vs. integer class targets
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Train
# Train
if __name__ == "__main__":

    def _accuracy(logits, targets):
        """Percentage of rows whose argmax matches the target class."""
        return (logits.argmax(dim=1) == targets).float().mean() * 100

    print("Train Loss, Acc | Val Loss, Acc")
    pbar = tqdm.trange(num_epochs, leave=True, position=0)
    for epoch in pbar:
        # One full-batch gradient step on the training split.
        optimizer.zero_grad()
        train_logits = model(X_train)
        loss = criterion(train_logits, y_train)
        loss.backward()
        optimizer.step()
        # Evaluate without building a graph; train_acc uses the pre-step logits.
        with torch.no_grad():
            train_acc = _accuracy(train_logits, y_train)
            val_logits = model(X_val)
            val_loss = criterion(val_logits, y_val)
            val_acc = _accuracy(val_logits, y_val)
        msg = f"{loss:10.2f}, {train_acc:>3.0f} | {val_loss:>8.2f}, {val_acc:>4.0f}"
        pbar.set_description(msg)
        if log:
            logger.log(
                model=model,
                epoch=epoch,
                train_loss=loss,
                train_acc=train_acc,
                val_loss=val_loss,
                val_acc=val_acc,
            )