File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change 3
3
from xor_dataset import XORDataset
4
4
5
5
BATCH_SIZE = 32
6
+ HIDDEN_SIZE = 1
7
+ NUM_LAYERS = 1
6
8
7
- model = torch .nn .LSTM (batch_first = True , input_size = 1 , hidden_size = 1 , num_layers = 1 )
9
+ model = torch .nn .LSTM (
10
+ batch_first = True , input_size = 1 , hidden_size = HIDDEN_SIZE , num_layers = NUM_LAYERS )
8
11
9
- optimizer = torch .optim .SGD (model .parameters (), lr = 1e-03 )
12
+ optimizer = torch .optim .SGD (model .parameters (), lr = 1e-2 )
10
13
loss_fn = torch .nn .BCEWithLogitsLoss ()
11
14
train_loader = DataLoader (XORDataset (), batch_size = BATCH_SIZE , shuffle = True )
12
15
22
25
optimizer .zero_grad ()
23
26
24
27
# reset hidden state per sequence
25
- # state is (num_cells, batch_size, hidden_size)
26
- state_shape = 1 , BATCH_SIZE , 1
27
- h0 = c0 = inputs .new_zeros (state_shape )
28
+ h0 = c0 = inputs .new_zeros ((NUM_LAYERS , BATCH_SIZE , HIDDEN_SIZE ))
28
29
29
30
final_outputs , _ = model (inputs , (h0 , c0 ))
30
31
37
38
step += 1
38
39
39
40
loss_val = loss .item ()
40
- if step % 500 == 0 :
41
+ if step % 100 == 0 :
41
42
print (f'LOSS step { step } : { loss_val } ' )
You can’t perform that action at this time.
0 commit comments