@@ -1,42 +1,67 @@
+import argparse
 import torch
 from torch.utils.data import DataLoader
+from typing import NamedTuple
 from xor_dataset import XORDataset
+from utils import register_parser_types

-BATCH_SIZE = 32
-HIDDEN_SIZE = 1
-NUM_LAYERS = 1

-model = torch.nn.LSTM(
-    batch_first=True, input_size=1, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
+class ModelParams(NamedTuple):
+  # train loop
+  batch_size: int = 32
+  epochs: int = 10

-optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
-loss_fn = torch.nn.BCEWithLogitsLoss()
-train_loader = DataLoader(XORDataset(), batch_size=BATCH_SIZE, shuffle=True)
+  # lstm
+  hidden_size: int = 1
+  learning_rate: float = 1e-1
+  num_layers: int = 1

-step = 0

-for inputs, targets in train_loader:
-  # [batch, bits] -> [batch, bits, 1]
-  inputs = torch.unsqueeze(inputs, -1)
+def train(params: ModelParams):
+  model = torch.nn.LSTM(
+      batch_first=True, input_size=1, hidden_size=params.hidden_size, num_layers=params.num_layers)

-  # [1] -> [1, 1]
-  targets = torch.unsqueeze(targets, -1)
+  optimizer = torch.optim.SGD(model.parameters(), lr=params.learning_rate)
+  loss_fn = torch.nn.BCEWithLogitsLoss()
+  train_loader = DataLoader(XORDataset(), batch_size=params.batch_size, shuffle=True)

-  optimizer.zero_grad()
+  step = 0

-  # reset hidden state per sequence
-  h0 = c0 = inputs.new_zeros((NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE))
+  for epoch in range(1, params.epochs + 1):
+    for inputs, targets in train_loader:
+      # [batch, bits] -> [batch, bits, 1]
+      inputs = torch.unsqueeze(inputs, -1)

-  final_outputs, _ = model(inputs, (h0, c0))
+      # [batch, parity] -> [batch, parity, 1]
+      targets = torch.unsqueeze(targets, -1)

-  # select the last prediction
-  # XXX we should calculate parity per bit in the lstm
-  loss = loss_fn(final_outputs[:, -1], targets)
+      optimizer.zero_grad()

-  loss.backward()
-  optimizer.step()
-  step += 1
+      # reset hidden state per sequence
+      # (size from inputs: the final batch may be smaller than batch_size)
+      h0 = c0 = inputs.new_zeros((params.num_layers, inputs.shape[0], params.hidden_size))

-  loss_val = loss.item()
-  if step % 100 == 0:
-    print(f'LOSS step {step}: {loss_val}')
+      final_outputs, _ = model(inputs, (h0, c0))
+
+      # calculate the loss on every timestep's parity prediction
+      loss = loss_fn(final_outputs, targets)
+
+      loss.backward()
+      optimizer.step()
+      step += 1
+
+      loss_val = loss.item()
+      if step % 500 == 0:
+        print(f'epoch {epoch}, step {step}, loss {loss_val}')
+
+
+def get_arguments():
+  parser = argparse.ArgumentParser()
+  register_parser_types(parser, ModelParams)
+  arguments = parser.parse_args()
+  return arguments
+
+
+if __name__ == '__main__':
+  params = get_arguments()
+  train(params)
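
The register_parser_types helper imported from utils is not part of this diff. Here is a minimal sketch of what it plausibly does, assuming it turns each ModelParams field into a typed command-line flag using the NamedTuple's annotations and defaults:

import argparse


def register_parser_types(parser: argparse.ArgumentParser, params_cls) -> None:
  # one --flag per NamedTuple field, typed from its annotation and
  # defaulting to the value declared on the class (e.g. ModelParams)
  for name, field_type in params_cls.__annotations__.items():
    parser.add_argument(
        f'--{name}', type=field_type, default=params_cls._field_defaults.get(name))

With a helper along these lines, every default above becomes overridable from the shell, e.g. --batch_size 64 --learning_rate 0.05 --epochs 20.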
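
XORDataset (from xor_dataset) is likewise defined outside this diff. A hedged sketch consistent with the shape comments in the code, where inputs are [batch, bits] and each target bit is the running parity of the input prefix (which is what the per-timestep loss after this commit trains against); the dataset size and sequence length here are assumptions:

import torch
from torch.utils.data import Dataset


class XORDataset(Dataset):
  def __init__(self, num_sequences=100000, num_bits=50):
    # random bit strings, as floats for the LSTM input
    self.inputs = torch.randint(0, 2, (num_sequences, num_bits)).float()
    # target bit t is the XOR (parity) of input bits 0..t;
    # a cumulative sum mod 2 computes exactly that prefix parity
    self.targets = (self.inputs.cumsum(dim=-1) % 2).float()

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, idx):
    return self.inputs[idx], self.targets[idx]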