16
16
17
17
18
18
def evaluate (seq2seq_model , eval_pairs , criterion , eval = 'val' , graph = False ):
19
+ """
20
+ Evaluate model and return metrics.
21
+ """
19
22
with torch .no_grad ():
20
23
loss = 0
21
24
f1 = 0
@@ -63,6 +66,9 @@ def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
63
66
64
67
def train (input_tensor , target_tensor , seq2seq_model , optimizer , criterion , graph ,
65
68
adj_tensor = None , node_features = None ):
69
+ """
70
+ Train model for a single iteration.
71
+ """
66
72
optimizer .zero_grad ()
67
73
68
74
if graph :
@@ -83,6 +89,9 @@ def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, grap
83
89
84
90
def train_iters (seq2seq_model , n_iters , pairs , print_every = 1000 , learning_rate = 0.001 ,
85
91
model_dir = None , lang = None , graph = False ):
92
+ """
93
+ Run complete training of the model.
94
+ """
86
95
train_losses = []
87
96
val_losses = []
88
97
@@ -101,6 +110,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
101
110
102
111
optimizer = optim .Adam (seq2seq_model .parameters (), lr = learning_rate )
103
112
113
+ # Prepare data
104
114
if graph :
105
115
training_pairs = [tensors_from_pair_tokens_graph (random .choice (train_pairs ), lang )
106
116
for i in range (n_iters )]
@@ -113,6 +123,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
113
123
# test_tensor_pairs = [tensors_from_pair_tokens(test_pair, lang) for test_pair in test_pairs]
114
124
criterion = nn .NLLLoss ()
115
125
126
+ # Train
116
127
for iter in range (1 , n_iters + 1 ):
117
128
training_pair = training_pairs [iter - 1 ]
118
129
if graph :
@@ -145,11 +156,8 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
145
156
rouge_2 += rouge_2_temp
146
157
rouge_l += rouge_l_temp
147
158
148
- # print("Pred: {}".format(lang.to_tokens(pred)))
149
- # print("Target: {}".format(lang.to_tokens(target_tensor.numpy().reshape(-1))))
150
- # print()
151
-
152
159
if iter % print_every == 0 :
160
+ # Evaluate
153
161
print_loss_avg = print_loss_total / print_every
154
162
print_loss_total = 0
155
163
print ('train (%d %d%%) %.4f' % (iter , iter / n_iters * 100 , print_loss_avg ))
@@ -160,9 +168,6 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
160
168
train_loss = print_loss_avg
161
169
val_loss , val_f1 , val_rouge_2 , val_rouge_l = evaluate (seq2seq_model , val_tensor_pairs ,
162
170
criterion , graph = graph )
163
- # test_loss, test_f1, test_rouge_2, test_rouge_l = evaluate(seq2seq_model,
164
- # test_tensor_pairs,
165
- # criterion, eval='test')
166
171
167
172
if not val_losses or val_loss < min (val_losses ):
168
173
torch .save (seq2seq_model .state_dict (), model_dir + 'model.pt' )
@@ -176,6 +181,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
176
181
val_rouge_2_scores .append (val_rouge_2 )
177
182
val_rouge_l_scores .append (val_rouge_l )
178
183
184
+ # Store results
179
185
results = {'train_losses' : train_losses ,
180
186
'val_losses' : val_losses ,
181
187
'val_f1_scores' : val_f1_scores ,
0 commit comments