
Commit 152c4ab

Add documentation in training files
1 parent e33b412 commit 152c4ab

File tree

2 files changed: 23 additions & 7 deletions

training/train.py

Lines changed: 10 additions & 0 deletions
@@ -11,18 +11,27 @@
 
 
 def main():
+    """
+    Entry-point for running the models.
+    """
+
+    # Create directory for saving results
     model_dir = '../results/{}/'.format(opt.model_name)
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
+
+    # Store hyperparams
     with open(model_dir + 'hyperparams.txt', 'w') as f:
         f.write(str(opt))
 
+    # Prepare data
     if opt.graph:
         lang, pairs = prepare_data(num_samples=opt.n_samples)
         pairs = [pair for pair in pairs if len(pair[0][1][0]) > 0]
     else:
         lang, pairs = prepare_tokens(num_samples=opt.n_samples)
 
+    # Create model
     hidden_size = 256
     encoder = LSTMEncoder(lang.n_words, hidden_size, opt.device).to(opt.device)
 
@@ -38,6 +47,7 @@ def main():
     else:
         model = FullModel(encoder=encoder, decoder=decoder, device=opt.device)
 
+    # Train model
     train_iters(model, opt.iterations, pairs, print_every=opt.print_every, model_dir=model_dir,
                 lang=lang, graph=opt.graph)
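
Note: main() reads its configuration from a module-level opt object (opt.model_name, opt.graph, opt.n_samples, opt.iterations, opt.print_every, opt.device) that is not defined in this diff. Below is a minimal sketch of how such a namespace could be built with argparse; the attribute names come from the diff, but the flag spellings, defaults, and device handling are assumptions, not the repository's actual parser.

# Hypothetical sketch only: attribute names are taken from the diff above,
# but the defaults and flag spellings are assumed for illustration.
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='lstm')
parser.add_argument('--graph', action='store_true')
parser.add_argument('--n_samples', type=int, default=10000)
parser.add_argument('--iterations', type=int, default=50000)
parser.add_argument('--print_every', type=int, default=1000)
opt = parser.parse_args()
# Pick GPU when available; this convention is also an assumption.
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')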

training/train_model.py

Lines changed: 13 additions & 7 deletions
@@ -16,6 +16,9 @@
 
 
 def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
+    """
+    Evaluate model and return metrics.
+    """
     with torch.no_grad():
         loss = 0
         f1 = 0
@@ -63,6 +66,9 @@ def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
 
 def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph,
           adj_tensor=None, node_features=None):
+    """
+    Train model for a single iteration.
+    """
     optimizer.zero_grad()
 
     if graph:
@@ -83,6 +89,9 @@ def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph,
 
 def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
                 model_dir=None, lang=None, graph=False):
+    """
+    Run complete training of the model.
+    """
     train_losses = []
     val_losses = []
 
@@ -101,6 +110,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
 
     optimizer = optim.Adam(seq2seq_model.parameters(), lr=learning_rate)
 
+    # Prepare data
     if graph:
         training_pairs = [tensors_from_pair_tokens_graph(random.choice(train_pairs), lang)
                           for i in range(n_iters)]
@@ -113,6 +123,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
     # test_tensor_pairs = [tensors_from_pair_tokens(test_pair, lang) for test_pair in test_pairs]
     criterion = nn.NLLLoss()
 
+    # Train
     for iter in range(1, n_iters + 1):
         training_pair = training_pairs[iter - 1]
         if graph:
@@ -145,11 +156,8 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
         rouge_2 += rouge_2_temp
         rouge_l += rouge_l_temp
 
-        # print("Pred: {}".format(lang.to_tokens(pred)))
-        # print("Target: {}".format(lang.to_tokens(target_tensor.numpy().reshape(-1))))
-        # print()
-
         if iter % print_every == 0:
+            # Evaluate
             print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
             print('train (%d %d%%) %.4f' % (iter, iter / n_iters * 100, print_loss_avg))
@@ -160,9 +168,6 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
             train_loss = print_loss_avg
             val_loss, val_f1, val_rouge_2, val_rouge_l = evaluate(seq2seq_model, val_tensor_pairs,
                                                                   criterion, graph=graph)
-            # test_loss, test_f1, test_rouge_2, test_rouge_l = evaluate(seq2seq_model,
-            #                                                           test_tensor_pairs,
-            #                                                           criterion, eval='test')
 
             if not val_losses or val_loss < min(val_losses):
                 torch.save(seq2seq_model.state_dict(), model_dir + 'model.pt')
@@ -176,6 +181,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
             val_rouge_2_scores.append(val_rouge_2)
             val_rouge_l_scores.append(val_rouge_l)
 
+            # Store results
             results = {'train_losses': train_losses,
                        'val_losses': val_losses,
                        'val_f1_scores': val_f1_scores,
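
One behavior worth calling out in train_iters is the checkpointing rule visible in the later hunks: the model is written to model.pt only when the current validation loss beats every loss recorded so far (if not val_losses or val_loss < min(val_losses)), and the comparison runs before the new loss is appended. Below is a self-contained sketch of that pattern; the nn.Linear model, directory, and loss values are stand-ins for illustration, not the repository's objects.

# Best-so-far checkpointing, isolated from the training loop.
# The model, directory, and losses below are placeholders.
import os

import torch
import torch.nn as nn

model = nn.Linear(4, 2)          # stand-in for seq2seq_model
model_dir = 'results/demo/'      # stand-in for the real model_dir
os.makedirs(model_dir, exist_ok=True)

val_losses = []
for val_loss in [0.9, 0.7, 0.8, 0.6]:    # simulated validation losses
    if not val_losses or val_loss < min(val_losses):
        # Save only on a strict improvement over all previous losses
        torch.save(model.state_dict(), model_dir + 'model.pt')
    val_losses.append(val_loss)

With this ordering, model.pt always holds the parameters from the best validation step seen so far, so a later overfitting phase cannot overwrite a good checkpoint.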
