Skip to content

Commit 44aa961

Browse files
committed
minor cleanup
1 parent f39c71b commit 44aa961

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

train.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def train(params: ModelParams):
5757
optimizer = torch.optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum)
5858
loss_fn = torch.nn.BCEWithLogitsLoss()
5959
train_loader = DataLoader(XORDataset(), batch_size=params.batch_size, shuffle=True)
60+
test_loader = DataLoader(XORDataset(train=False), batch_size=params.batch_size)
6061

6162
step = 0
6263

@@ -75,31 +76,27 @@ def train(params: ModelParams):
7576
optimizer.step()
7677
step += 1
7778

78-
loss_val = loss.item()
79-
accuracy_val = ((predictions > 0.5) == (targets > 0.5)).type(torch.FloatTensor).mean()
79+
accuracy = ((predictions > 0.5) == (targets > 0.5)).type(torch.FloatTensor).mean()
8080

8181
if step % 500 == 0:
82-
print(f'epoch {epoch}, step {step}, loss {loss_val:.{4}f}, accuracy {accuracy_val:.{3}f}')
82+
print(f'epoch {epoch}, step {step}, loss {loss.item():.{4}f}, accuracy {accuracy:.{3}f}')
8383

8484
# evaluate per epoch
85-
evaluate(model)
86-
85+
evaluate(model, test_loader)
8786

88-
def evaluate(model):
89-
test_loader = DataLoader(XORDataset(train=False), batch_size=params.batch_size)
9087

91-
prediction_is_correct = np.array([])
88+
def evaluate(model, loader):
89+
is_correct = np.array([])
9290

93-
for inputs, targets in test_loader:
91+
for inputs, targets in loader:
9492
inputs = inputs.to(params.device)
9593
targets = targets.to(params.device)
9694
with torch.no_grad():
9795
logits, predictions = model(inputs)
98-
prediction_is_correct = np.append(prediction_is_correct,
99-
((predictions > 0.5) == (targets > 0.5)))
96+
is_correct = np.append(is_correct, ((predictions > 0.5) == (targets > 0.5)))
10097

101-
accuracy_val = prediction_is_correct.mean()
102-
print(f'test accuracy {accuracy_val:.{3}f}')
98+
accuracy = is_correct.mean()
99+
print(f'test accuracy {accuracy:.{3}f}')
103100

104101

105102
def get_arguments():

0 commit comments

Comments
 (0)