Skip to content

Commit f39c71b

Browse files
committed
Use SGD with momentum to converge
1 parent 90a08c0 commit f39c71b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

train.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,12 @@ class ModelParams(NamedTuple):
11 11
# train loop
12 12
batch_size: int = 32
13 13
device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
14-
epochs: int = 5
14+
epochs: int = 15
15 15

16 16
# lstm
17 17
hidden_size: int = 2
18 18
lr: float = 1e-1
19+
momentum: float = 0.9
19 20
num_layers: int = 1
20 21

21 22

@@ -53,7 +54,7 @@ def forward(self, inputs):
53 54
def train(params: ModelParams):
54 55
model = LSTM(params).to(params.device)
55 56

56-
optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
57+
optimizer = torch.optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum)
57 58
loss_fn = torch.nn.BCEWithLogitsLoss()
58 59
train_loader = DataLoader(XORDataset(), batch_size=params.batch_size, shuffle=True)
59 60

0 commit comments

Comments (0)