Skip to content

Commit 371d0eb

Browse files
committed
Use node features for graph encoder
1 parent ef11f06 commit 371d0eb

File tree

4 files changed

+37
-41
lines changed

4 files changed

+37
-41
lines changed

models/lstm_decoder.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,6 @@ def forward(self, input, hidden, encoder_hiddens, input_seq=None):
4444
output = torch.cat((output, context), 2)
4545
output = self.attention_combine(output)
4646

47-
elif self.pointer_network:
48-
# Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last
49-
# dimension is a concatenation of the ith encoder hidden state and the current decoder
50-
# hidden
51-
hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
52-
dim=2)
53-
54-
# attention_coeff has shape [seq_len] and contains the attention coefficients for
55-
# each encoder hidden state
56-
attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
57-
# TODO: This is the output already
58-
5947
output = F.relu(output)
6048
output, hidden = self.gru(output, hidden)
6149
output = self.softmax(self.out(output[0]))

models/lstm_to_lstm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, encoder, decoder, device, graph_encoder=None, graph=False):
1515
assert encoder.hidden_size == decoder.hidden_size, "Hidden dimensions of encoder and decoder " \
1616
"must be equal!"
1717

18-
def forward(self, sequence, target, adj=None):
18+
def forward(self, sequence, target, adj=None, node_features=None):
1919
batch_size = 1
2020
max_len = target.shape[0]
2121
target_vocab_size = self.decoder.output_size
@@ -34,6 +34,7 @@ def forward(self, sequence, target, adj=None):
3434
n_tokens = sequence.size(0)
3535
x = torch.zeros(n_nodes, encoder_output.size(2)).to(self.device)
3636
x[:n_tokens, :] = encoder_output.view(encoder_output.size(1), encoder_output.size(2))
37+
x[n_tokens:, :] = node_features
3738
graph_hidden = self.graph_encoder(x=x, adj=adj)
3839

3940
# TODO: Combine the graph representation with the seq_encoder final layer using mlp

models/lstm_to_lstm_full_training.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
2525
eval_pair = eval_pairs[i]
2626
input_tensor = eval_pair[0][0].to(device)
2727
adj_tensor = eval_pair[0][1].to(device)
28+
node_features = eval_pair[0][2].to(device)
2829
target_tensor = eval_pair[1].to(device)
2930

3031
output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor,
31-
target=target_tensor.view(-1))
32+
target=target_tensor.view(-1), node_features=node_features)
3233
else:
3334
eval_pair = eval_pairs[i]
3435
input_tensor = eval_pair[0]
@@ -59,12 +60,13 @@ def evaluate(seq2seq_model, eval_pairs, criterion, eval='val', graph=False):
5960
return loss, f1, rouge_2, rouge_l
6061

6162

62-
def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph, adj_tensor=None):
63+
def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion, graph,
64+
adj_tensor=None, node_features=None):
6365
optimizer.zero_grad()
6466

6567
if graph:
6668
output = seq2seq_model(sequence=input_tensor.view(-1), adj=adj_tensor,
67-
target=target_tensor.view(-1))
69+
target=target_tensor.view(-1), node_features=node_features)
6870
else:
6971
output = seq2seq_model(sequence=input_tensor.view(-1), target=target_tensor.view(-1))
7072

@@ -83,9 +85,9 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
8385
train_losses = []
8486
val_losses = []
8587

86-
# test_f1_scores = []
87-
# test_rouge_2_scores = []
88-
# test_rouge_l_scores = []
88+
val_f1_scores = []
89+
val_rouge_2_scores = []
90+
val_rouge_l_scores = []
8991

9092
print_loss_total = 0 # Reset every print_every
9193
plot_loss_total = 0 # Reset every plot_every
@@ -115,10 +117,11 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
115117
if graph:
116118
input_tensor = training_pair[0][0].to(device)
117119
adj_tensor = training_pair[0][1].to(device)
120+
node_features = training_pair[0][2].to(device)
118121
target_tensor = training_pair[1].to(device)
119122

120123
loss, pred = train(input_tensor, target_tensor, seq2seq_model, optimizer,
121-
criterion, adj_tensor=adj_tensor, graph=graph)
124+
criterion, adj_tensor=adj_tensor, graph=graph, node_features=node_features)
122125
else:
123126
input_tensor = training_pair[0]
124127
target_tensor = training_pair[1]
@@ -168,11 +171,12 @@ def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0
168171
val_losses.append(val_loss)
169172
# test_losses.append(test_loss)
170173

171-
# test_f1_scores.append(test_f1)
172-
# test_rouge_2_scores.append(test_rouge_2)
173-
# test_rouge_l_scores.append(test_rouge_l)
174+
val_f1_scores.append(val_f1)
175+
val_rouge_2_scores.append(val_rouge_2)
176+
val_rouge_l_scores.append(val_rouge_l)
174177

175-
pickle.dump([train_losses, val_losses],
176-
open(model_dir + 'res.pkl', 'wb'))
178+
pickle.dump([train_losses, val_losses, val_f1_scores, val_rouge_2_scores,
179+
val_rouge_l_scores],
180+
open('results/res.pkl', 'wb'))
177181

178182
plot_loss(train_losses, val_losses, file_path=model_dir + 'loss.jpg')

tokens_util.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,23 @@ def to_tokens(self, idxs):
3737
return np.array([self.index2word[idx] for idx in idxs])
3838

3939

40+
def read_data():
41+
data = pickle.load(open('data/methods_tokens_graphs2.pkl', 'rb'))
42+
# data = pickle.load(open('data/methods_tokens_graphs.pkl', 'rb'))
43+
methods_source = data['methods_source']
44+
methods_graphs = data['methods_graphs']
45+
methods_names = data['methods_names']
46+
47+
pairs = [((methods_source[i], methods_graphs[i]), methods_names[i]) for i in range(len(
48+
methods_source))]
49+
np.random.shuffle(pairs)
50+
51+
return pairs
52+
53+
4054
def read_tokens():
41-
data = pickle.load(open('data/methods_tokens_data.pkl', 'rb'))
55+
data = pickle.load(open('data/methods_tokens_graphs.pkl', 'rb'))
56+
# data = pickle.load(open('data/methods_tokens_data.pkl', 'rb'))
4257
# data = pickle.load(open('../data/methods_tokens_data.pkl', 'rb'))
4358
methods_source = data['methods_source']
4459
methods_names = data['methods_names']
@@ -75,19 +90,6 @@ def prepare_data(num_samples=None):
7590
return lang, pairs
7691

7792

78-
def read_data():
79-
data = pickle.load(open('data/methods_tokens_graphs.pkl', 'rb'))
80-
methods_source = data['methods_source']
81-
methods_graphs = data['methods_graphs']
82-
methods_names = data['methods_names']
83-
84-
pairs = [((methods_source[i], methods_graphs[i]), methods_names[i]) for i in range(len(
85-
methods_source))]
86-
np.random.shuffle(pairs)
87-
88-
return pairs
89-
90-
9193
def indexes_from_sentence_tokens(lang, sentence):
9294
return [lang.word2index[word] for word in sentence]
9395

@@ -117,9 +119,10 @@ def sparse_adj_from_edges(edges):
117119

118120
def tensors_from_pair_tokens_graph(pair, lang):
119121
input_tensor = tensor_from_sentence_tokens(lang, pair[0][0])
120-
input_adj = sparse_adj_from_edges(pair[0][1])
122+
input_adj = sparse_adj_from_edges(pair[0][1][0])
123+
node_features = torch.tensor(pair[0][1][1])
121124
target_tensor = tensor_from_sentence_tokens(lang, pair[1])
122-
return (input_tensor, input_adj), target_tensor
125+
return (input_tensor, input_adj, node_features), target_tensor
123126

124127

125128
def plot_loss(train_losses, val_losses, test_losses=None, file_path='plots/loss.jpg'):

0 commit comments

Comments (0)