Skip to content

Commit 3ed1935

Browse files
committed
predict parity per bit
- setup cli arguments for parameters
1 parent 77a60f0 commit 3ed1935

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

train.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,66 @@
1+
import argparse
12
import torch
23
from torch.utils.data import DataLoader
4+
from typing import NamedTuple
35
from xor_dataset import XORDataset
6+
from utils import register_parser_types
47

5-
BATCH_SIZE = 32
6-
HIDDEN_SIZE = 1
7-
NUM_LAYERS = 1
88

9-
model = torch.nn.LSTM(
10-
batch_first=True, input_size=1, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
9+
class ModelParams(NamedTuple):
    """Hyper-parameters for a training run; every field has a default."""

    # --- training loop ---
    batch_size: int = 32  # samples per DataLoader batch
    epochs: int = 10  # passes over the dataset

    # --- LSTM / optimizer ---
    hidden_size: int = 1  # LSTM hidden width
    learning_rate: float = 1e-1  # SGD step size
    num_layers: int = 1  # stacked LSTM layers
1518

16-
step = 0
1719

18-
for inputs, targets in train_loader:
19-
# [batch, bits] -> [batch, bits, 1]
20-
inputs = torch.unsqueeze(inputs, -1)
20+
def train(params: ModelParams):
    """Train an LSTM on ``XORDataset`` with a per-position BCE-with-logits loss.

    Args:
        params: Hyper-parameters. Reads ``batch_size``, ``epochs``,
            ``hidden_size``, ``learning_rate`` and ``num_layers``.

    Prints the epoch, global step and loss every 500 optimizer steps.
    """
    model = torch.nn.LSTM(
        batch_first=True,
        input_size=1,
        hidden_size=params.hidden_size,
        num_layers=params.num_layers,
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=params.learning_rate)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    train_loader = DataLoader(XORDataset(), batch_size=params.batch_size, shuffle=True)

    step = 0

    # range() excludes its stop value; +1 so exactly `params.epochs` passes run
    # while the printed epoch number stays 1-based.
    for epoch in range(1, params.epochs + 1):
        for inputs, targets in train_loader:
            # [batch, bits] -> [batch, bits, 1]
            inputs = torch.unsqueeze(inputs, -1)

            # [batch, parity] -> [batch, parity, 1]
            targets = torch.unsqueeze(targets, -1)

            optimizer.zero_grad()

            # Reset hidden state per sequence. Size it from the actual batch:
            # the final batch of an epoch can be smaller than params.batch_size.
            batch_size = inputs.size(0)
            h0 = c0 = inputs.new_zeros((params.num_layers, batch_size, params.hidden_size))

            final_outputs, _ = model(inputs, (h0, c0))

            # Score the prediction at every position, not only the last one.
            # NOTE(review): the raw LSTM output has width hidden_size while the
            # targets have width 1, so this presumably only lines up when
            # hidden_size == 1 -- confirm before raising hidden_size.
            loss = loss_fn(final_outputs, targets)

            loss.backward()
            optimizer.step()
            step += 1

            loss_val = loss.item()
            if step % 500 == 0:
                print(f'epoch {epoch}, step {step}, loss {loss_val}')
55+
56+
57+
def get_arguments():
    """Parse the command line into a namespace with one attribute per ``ModelParams`` field."""
    cli = argparse.ArgumentParser()
    # One --<field> option per ModelParams field, typed and defaulted from it.
    register_parser_types(cli, ModelParams)
    return cli.parse_args()
62+
63+
64+
if __name__ == '__main__':
    # Script entry point: read hyper-parameters from the CLI, then train.
    train(get_arguments())

utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
import os
22
import shutil
3+
import typing
4+
5+
# ------------------------- Parser Utils -------------------------
6+
7+
8+
def register_parser_types(parser, params_named_tuple):
    """Register one ``--<field>`` argument per field of a typed named tuple.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.
        params_named_tuple: A ``typing.NamedTuple`` class whose annotations
            give each argument's type and whose field defaults give each
            argument's default.

    Fields declared without a default become required options. The class is
    never instantiated, so it may have fields without defaults.
    """
    # XXX upgrade to support dataclass instead after python 3.7.0
    # Custom converters: 'true' (any case) -> True, '1,2,3' -> (1, 2, 3).
    parser.register('type', bool, lambda v: v.lower() == 'true')
    parser.register('type', typing.List[int], lambda v: tuple(map(int, v.split(','))))

    hints = typing.get_type_hints(params_named_tuple)
    # Read declared defaults from the class rather than calling it with no
    # arguments, which raises TypeError if any field has no default.
    defaults = getattr(params_named_tuple, '_field_defaults', {})

    for key, _type in hints.items():
        if key in defaults:
            parser.add_argument(f'--{key}', type=_type, default=defaults[key])
        else:
            parser.add_argument(f'--{key}', type=_type, required=True)
19+
20+
21+
# ------------------------- Path Utils -------------------------
22+
323

424
def ensure_path(path):
    """Create the path if it does not exist and return it.

    Missing intermediate directories are created too. An already existing
    path is left untouched.
    """
    try:
        # Attempt the creation directly instead of checking first, so a
        # concurrent creator cannot slip in between the check and makedirs.
        os.makedirs(path)
    except FileExistsError:
        pass
    return path
1029

30+
1131
def remove_path(path):
    """Delete the directory tree at ``path``; do nothing when it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

0 commit comments

Comments
 (0)