File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -11,11 +11,12 @@ class ModelParams(NamedTuple):
11
11
# train loop
12
12
batch_size : int = 32
13
13
device : str = 'cuda:0' if torch .cuda .is_available () else 'cpu'
14
- epochs : int = 5
14
+ epochs : int = 15
15
15
16
16
# lstm
17
17
hidden_size : int = 2
18
18
lr : float = 1e-1
19
+ momentum : float = 0.9
19
20
num_layers : int = 1
20
21
21
22
@@ -53,7 +54,7 @@ def forward(self, inputs):
53
54
def train (params : ModelParams ):
54
55
model = LSTM (params ).to (params .device )
55
56
56
- optimizer = torch .optim .Adam (model .parameters (), lr = params .lr )
57
+ optimizer = torch .optim .SGD (model .parameters (), lr = params .lr , momentum = params . momentum )
57
58
loss_fn = torch .nn .BCEWithLogitsLoss ()
58
59
train_loader = DataLoader (XORDataset (), batch_size = params .batch_size , shuffle = True )
59
60
You can’t perform that action at this time.
0 commit comments