Skip to content

Commit b94ca56

Browse files
committed
Save results more clearly
1 parent 4ea4acd commit b94ca56

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

models/lstm_to_lstm_full_training.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,14 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
175175
val_rouge_2_scores.append(val_rouge_2)
176176
val_rouge_l_scores.append(val_rouge_l)
177177

178-
pickle.dump([train_losses, val_losses, val_f1_scores, val_rouge_2_scores,
179-
val_rouge_l_scores], open(model_dir + 'results.pkl', 'wb'))
178+
results = {'train_losses': train_losses,
179+
'val_losses': val_losses,
180+
'val_f1_scores': val_f1_scores,
181+
'val_rouge_2_scores': val_rouge_2_scores,
182+
'val_rouge_l_scores': val_rouge_l_scores}
183+
184+
with open(model_dir + 'results.txt', 'w') as f:
185+
f.write(str(results))
186+
pickle.dump(results, open(model_dir + 'results.pkl', 'wb'))
180187

181188
plot_loss(train_losses, val_losses, file_path=model_dir + 'loss.jpg')

train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import torch
32
import argparse
43

54
from models.lstm_to_lstm import Seq2Seq
@@ -11,10 +10,12 @@
1110
from models.gcn_encoder import GCNEncoder
1211

1312

14-
def main(model_name):
15-
model_dir = '../results/{}/'.format(model_name)
13+
def main():
14+
model_dir = '../results/{}/'.format(opt.model_name)
1615
if not os.path.exists(model_dir):
1716
os.makedirs(model_dir)
17+
with open(model_dir + 'hyperparams.txt', 'w') as f:
18+
f.write(str(opt))
1819

1920
if opt.graph:
2021
lang, pairs = prepare_data(num_samples=opt.n_samples)
@@ -55,4 +56,4 @@ def main(model_name):
5556
print(opt)
5657

5758
if __name__ == "__main__":
58-
main(opt.model_name)
59+
main()

0 commit comments

Comments
 (0)