Skip to content

Commit

Permalink
fix(info): fix errors when save info
Browse files Browse the repository at this point in the history
  • Loading branch information
Yidadaa committed Jun 23, 2019
1 parent 45a303a commit 26644db
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ def train(model:nn.Sequential, dataloader:torch.utils.data.DataLoader, optimizer
loss.backward()
optimizer.step()

y_ = y_.argmax(dim=1)
acc = accuracy_score(y_.cpu().numpy(), y.cpu().numpy())

# 保存loss等信息
train_losses.append(loss)
train_losses.append(loss.item())
train_scores.append(acc)

if (i + 1) % config.log_interval == 0:
y_ = y_.argmax(dim=1)
acc = accuracy_score(y_.cpu().numpy(), y.cpu().numpy())
print('[Epoch %3d]Training %3d of %3d: acc = %.2f, loss = %.2f' % (epoch, i + 1, len(dataloader), acc, loss.item()))

return train_losses, train_scores
Expand Down
2 changes: 1 addition & 1 deletion train.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@


CUDA_VISIBLE_DEVICES=2 python train.py -r ./checkpoints/ep-1.pth
CUDA_VISIBLE_DEVICES=2 python train.py
1 change: 1 addition & 0 deletions train_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_losses": [[

0 comments on commit 26644db

Please sign in to comment.