Skip to content

Commit e1cd215

Browse files
committed
make state shape parameters explicit
1 parent a1fc948 commit e1cd215

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

train.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,13 @@
33
from xor_dataset import XORDataset
44

55
BATCH_SIZE = 32
6+
HIDDEN_SIZE = 1
7+
NUM_LAYERS = 1
68

7-
model = torch.nn.LSTM(batch_first=True, input_size=1, hidden_size=1, num_layers=1)
9+
model = torch.nn.LSTM(
10+
batch_first=True, input_size=1, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
811

9-
optimizer = torch.optim.SGD(model.parameters(), lr=1e-03)
12+
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
1013
loss_fn = torch.nn.BCEWithLogitsLoss()
1114
train_loader = DataLoader(XORDataset(), batch_size=BATCH_SIZE, shuffle=True)
1215

@@ -22,9 +25,7 @@
2225
optimizer.zero_grad()
2326

2427
# reset hidden state per sequence
25-
# state is (num_cells, batch_size, hidden_size)
26-
state_shape = 1, BATCH_SIZE, 1
27-
h0 = c0 = inputs.new_zeros(state_shape)
28+
h0 = c0 = inputs.new_zeros((NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE))
2829

2930
final_outputs, _ = model(inputs, (h0, c0))
3031

@@ -37,5 +38,5 @@
3738
step += 1
3839

3940
loss_val = loss.item()
40-
if step % 500 == 0:
41+
if step % 100 == 0:
4142
print(f'LOSS step {step}: {loss_val}')

0 commit comments

Comments
 (0)