Skip to content

Commit ef11f06

Browse files
committed
Refactor code so that all models are run from same script based on configurations
1 parent da491ee commit ef11f06

File tree

4 files changed

+91
-60
lines changed

4 files changed

+91
-60
lines changed

models/lstm_to_lstm.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33

44

55
class Seq2Seq(nn.Module):
6-
def __init__(self, encoder, decoder, device):
6+
def __init__(self, encoder, decoder, device, graph_encoder=None, graph=False):
77
super().__init__()
88

99
self.encoder = encoder
10+
self.graph_encoder = graph_encoder
1011
self.decoder = decoder
1112
self.device = device
13+
self.graph = graph
1214

1315
assert encoder.hidden_size == decoder.hidden_size, "Hidden dimensions of encoder and decoder " \
1416
"must be equal!"
1517

16-
def forward(self, sequence, target):
18+
def forward(self, sequence, target, adj=None):
1719
batch_size = 1
1820
max_len = target.shape[0]
1921
target_vocab_size = self.decoder.output_size
@@ -26,6 +28,18 @@ def forward(self, sequence, target):
2628
# output contains the hidden states for all input elements
2729
encoder_output, hidden = self.encoder(sequence)
2830

31+
if self.graph:
32+
# graph_hidden has shape [1, 1, hidden_size] and contains a graph representation
33+
n_nodes = adj.size(0)
34+
n_tokens = sequence.size(0)
35+
x = torch.zeros(n_nodes, encoder_output.size(2)).to(self.device)
36+
x[:n_tokens, :] = encoder_output.view(encoder_output.size(1), encoder_output.size(2))
37+
graph_hidden = self.graph_encoder(x=x, adj=adj)
38+
39+
# TODO: Combine the graph representation with the seq_encoder final layer using mlp
40+
41+
hidden = (graph_hidden.view(1, 1, graph_hidden.size(0)), hidden[1])
42+
2943
# first input to the decoder is the <sos> tokens
3044
input = torch.tensor([[0]], device=self.device)
3145

models/lstm_to_lstm_full_training.py

Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
from __future__ import unicode_literals, print_function, division
22
import random
3-
from models.lstm_encoder import LSTMEncoder
4-
from models.lstm_decoder import LSTMDecoder
5-
from models.attention_decoder import AttentionDecoder
6-
from models.lstm_to_lstm import Seq2Seq
7-
from tokens_util import prepare_tokens, tensors_from_pair_tokens, plot_loss
3+
from tokens_util import tensors_from_pair_tokens, plot_loss, tensors_from_pair_tokens_graph
84

95
import torch
106
import torch.nn as nn
@@ -13,35 +9,34 @@
139
import numpy as np
1410
from metrics import compute_rouge_scores
1511
import pickle
16-
import os
17-
import sys
1812

19-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2013

21-
# lang, pairs = prepare_tokens()
22-
# # test_pairs = pairs[-10000:]
23-
# # val_pairs = pairs[-20000:-10000]
24-
# # train_pairs = pairs[:-20000]
25-
# # pairs = pairs[:100]
26-
# train_pairs, val_pairs, test_pairs = np.split(pairs, [int(.8*len(pairs)), int(.9*len(pairs))])
27-
#
28-
# test_pairs = test_pairs
29-
# val_pairs = val_pairs
30-
# train_pairs = train_pairs
14+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3115

3216

33-
def evaluate(seq2seq_model, eval_pairs, criterion, eval='val'):
17+
def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
3418
with torch.no_grad():
3519
loss = 0
3620
f1 = 0
3721
rouge_2 = 0
3822
rouge_l = 0
3923
for i in range(len(eval_pairs)):
40-
eval_pair = eval_pairs[i]
41-
input_tensor = eval_pair[0]
42-
target_tensor = eval_pair[1]
24+
if graph:
25+
eval_pair = eval_pairs[i]
26+
input_tensor = eval_pair[0][0].to(device)
27+
adj_tensor = eval_pair[0][1].to(device)
28+
target_tensor = eval_pair[1].to(device)
29+
30+
output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor,
31+
target=target_tensor.view(-1))
32+
else:
33+
eval_pair = eval_pairs[i]
34+
input_tensor = eval_pair[0]
35+
target_tensor = eval_pair[1]
36+
37+
output = seq2seq_model(sequence=input_tensor.view(-1), target=target_tensor.view(
38+
-1))
4339

44-
output = seq2seq_model(input_tensor.view(-1), target_tensor.view(-1))
4540
loss += criterion(output.view(-1, output.size(2)), target_tensor.view(-1))
4641
pred = output.view(-1, output.size(2)).argmax(1).cpu().numpy()
4742

@@ -64,10 +59,15 @@ def evaluate(seq2seq_model, eval_pairs, criterion, eval='val'):
6459
return loss, f1, rouge_2, rouge_l
6560

6661

67-
def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion):
62+
def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph, adj_tensor=None):
6863
optimizer.zero_grad()
6964

70-
output = seq2seq_model(input_tensor.view(-1), target_tensor.view(-1))
65+
if graph:
66+
output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor,
67+
target=target_tensor.view(-1))
68+
else:
69+
output = seq2seq_model(sequence=input_tensor.view(-1), target=target_tensor.view(-1))
70+
7171
loss = criterion(output.view(-1, output.size(2)), target_tensor.view(-1))
7272
pred = output.view(-1, output.size(2)).argmax(1).cpu().numpy()
7373

@@ -79,7 +79,7 @@ def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion):
7979

8080

8181
def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.001,
82-
model_dir=None, lang=None):
82+
model_dir=None, lang=None, graph=False):
8383
train_losses = []
8484
val_losses = []
8585

@@ -97,18 +97,35 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
9797
[int(.8 * len(pairs)), int(.9 * len(pairs))])
9898

9999
optimizer = optim.Adam(seq2seq_model.parameters(), lr=learning_rate)
100-
training_pairs = [tensors_from_pair_tokens(random.choice(train_pairs), lang)
101-
for i in range(n_iters)]
102-
val_tensor_pairs = [tensors_from_pair_tokens(val_pair, lang) for val_pair in val_pairs]
100+
101+
if graph:
102+
training_pairs = [tensors_from_pair_tokens_graph(random.choice(train_pairs), lang)
103+
for i in range(n_iters)]
104+
val_tensor_pairs = [tensors_from_pair_tokens_graph(val_pair, lang) for val_pair in val_pairs]
105+
else:
106+
training_pairs = [tensors_from_pair_tokens(random.choice(train_pairs), lang)
107+
for i in range(n_iters)]
108+
val_tensor_pairs = [tensors_from_pair_tokens(val_pair, lang) for val_pair in val_pairs]
109+
103110
# test_tensor_pairs = [tensors_from_pair_tokens(test_pair, lang) for test_pair in test_pairs]
104111
criterion = nn.NLLLoss()
105112

106113
for iter in range(1, n_iters + 1):
107114
training_pair = training_pairs[iter - 1]
108-
input_tensor = training_pair[0]
109-
target_tensor = training_pair[1]
115+
if graph:
116+
input_tensor = training_pair[0][0].to(device)
117+
adj_tensor = training_pair[0][1].to(device)
118+
target_tensor = training_pair[1].to(device)
119+
120+
loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer,
121+
criterion, adj_tensor=adj_tensor, graph=graph)
122+
else:
123+
input_tensor = training_pair[0]
124+
target_tensor = training_pair[1]
125+
126+
loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion,
127+
graph=graph)
110128

111-
loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion)
112129
print_loss_total += loss
113130
plot_loss_total += loss
114131

@@ -138,7 +155,7 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
138155

139156
train_loss = print_loss_avg
140157
val_loss, val_f1, val_rouge_2, val_rouge_l = evaluate(seq2seq_model, val_tensor_pairs,
141-
criterion)
158+
criterion, graph=graph)
142159
# test_loss, test_f1, test_rouge_2, test_rouge_l = evaluate(seq2seq_model,
143160
# test_tensor_pairs,
144161
# criterion, eval='test')
@@ -159,21 +176,3 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
159176
open(model_dir + 'res.pkl', 'wb'))
160177

161178
plot_loss(train_losses, val_losses, file_path=model_dir + 'loss.jpg')
162-
163-
164-
# def main(model_name):
165-
# model_dir = '../results/{}/'.format(model_name)
166-
# if not os.path.exists(model_dir):
167-
# os.makedirs(model_dir)
168-
#
169-
# hidden_size = 256
170-
# encoder1 = LSTMEncoder(lang.n_words, hidden_size).to(device)
171-
# attn_decoder1 = LSTMDecoder(hidden_size, lang.n_words).to(device)
172-
# # attn_decoder1 = AttentionDecoder(hidden_size, lang.n_words).to(device)
173-
# lstm2lstm = Seq2Seq(encoder1, attn_decoder1, device)
174-
# train_iters(lstm2lstm, 500000, print_every=100, model_dir=model_dir)
175-
# # train_iters(lstm2lstm, 50, print_every=10, plot_every=1000)
176-
#
177-
#
178-
# if __name__ == "__main__":
179-
# main(sys.argv[1])

tokens_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def prepare_tokens(num_samples=None):
6262
return lang, pairs
6363

6464

65-
def prepare_data():
65+
def prepare_data(num_samples=None):
6666
lang = TokenLang('code')
6767
pairs = read_data()
68+
pairs = pairs if not num_samples else pairs[:num_samples]
6869
print("Read %s sentence pairs" % len(pairs))
6970
for pair in pairs:
7071
lang.addSentence(pair[0][0])

train.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,39 @@
66
from models.lstm_to_lstm_full_training import train_iters
77
from models.lstm_encoder import LSTMEncoder
88
from models.lstm_decoder import LSTMDecoder
9-
from tokens_util import prepare_tokens
9+
from tokens_util import prepare_tokens, prepare_data
10+
from models.gat_encoder import GATEncoder
11+
from models.gcn_encoder import GCNEncoder
1012

1113

1214
def main(model_name):
1315
model_dir = '../results/{}/'.format(model_name)
1416
if not os.path.exists(model_dir):
1517
os.makedirs(model_dir)
1618

17-
lang, pairs = prepare_tokens(num_samples=opt.n_samples)
19+
if opt.graph:
20+
lang, pairs = prepare_data(num_samples=opt.n_samples)
21+
pairs = [pair for pair in pairs if len(pair[0][1]) > 0]
22+
else:
23+
lang, pairs = prepare_tokens(num_samples=opt.n_samples)
1824

1925
hidden_size = 256
20-
encoder1 = LSTMEncoder(lang.n_words, hidden_size, opt.device).to(opt.device)
26+
encoder = LSTMEncoder(lang.n_words, hidden_size, opt.device).to(opt.device)
2127

2228
decoder = LSTMDecoder(hidden_size, lang.n_words, opt.device, attention=opt.attention).to(
2329
opt.device)
24-
lstm2lstm = Seq2Seq(encoder1, decoder, opt.device)
25-
train_iters(lstm2lstm, opt.iterations, pairs, print_every=opt.print_every, model_dir=model_dir,
26-
lang=lang)
30+
if opt.graph:
31+
if opt.gat:
32+
graph_encoder = GATEncoder(hidden_size, hidden_size)
33+
else:
34+
graph_encoder = GCNEncoder(hidden_size, hidden_size)
35+
model = Seq2Seq(encoder=encoder, graph_encoder=graph_encoder, decoder=decoder,
36+
device=opt.device)
37+
else:
38+
model = Seq2Seq(encoder=encoder, decoder=decoder, device=opt.device)
39+
40+
train_iters(model, opt.iterations, pairs, print_every=opt.print_every, model_dir=model_dir,
41+
lang=lang, graph=opt.graph)
2742

2843

2944
parser = argparse.ArgumentParser()
@@ -33,6 +48,8 @@ def main(model_name):
3348
parser.add_argument('--n_samples', type=int, default=None, help='Number of samples to train on')
3449
parser.add_argument('--print_every', type=int, default=1000, help='Number of samples to train on')
3550
parser.add_argument('--iterations', type=int, default=100, help='Number of samples to train on')
51+
parser.add_argument('--graph', type=bool, default=False, help='Number of samples to train on')
52+
parser.add_argument('--gat', type=bool, default=False, help='Number of samples to train on')
3653

3754
opt = parser.parse_args()
3855
print(opt)

0 commit comments

Comments
 (0)