Skip to content

Commit 7cdeda7

Browse files
committed
Combine graph hidden with encoder hidden
1 parent 57ca9a4 commit 7cdeda7

File tree

1 file changed

+5
-2
lines changed


models/lstm_to_lstm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch.nn as nn
3+
import torch.functional as F
(Review note: `torch.functional` does not export `relu` — the activation functions live in `torch.nn.functional`. As committed, the later call `F.relu(new_hidden)` should raise an AttributeError; the intended import is presumably `import torch.nn.functional as F` — verify against a later commit.)
34

45

56
class Seq2Seq(nn.Module):
@@ -11,6 +12,7 @@ def __init__(self, encoder, decoder, device, graph_encoder=None, graph=False):
1112
self.decoder = decoder
1213
self.device = device
1314
self.graph = graph
15+
self.combine = nn.Linear(2 * encoder.hidden_size, encoder.hidden_size)
1416

1517
assert encoder.hidden_size == decoder.hidden_size, "Hidden dimensions of encoder and decoder " \
1618
"must be equal!"
@@ -37,9 +39,10 @@ def forward(self, sequence, target, adj=None, node_features=None):
3739
x[n_tokens:, :] = node_features
3840
graph_hidden = self.graph_encoder(x=x, adj=adj)
3941

40-
# TODO: Combine the graph representation with the seq_encoder final layer using mlp
42+
new_hidden = self.combine(torch.cat((graph_hidden, torch.squeeze(hidden[1]))))
43+
new_hidden = F.relu(new_hidden)
4144

42-
hidden = (graph_hidden.view(1, 1, graph_hidden.size(0)), hidden[1])
45+
hidden = (new_hidden.view(1, 1, new_hidden.size(0)), hidden[1])
4346

4447
# first input to the decoder is the <sos> tokens
4548
input = torch.tensor([[0]], device=self.device)

0 commit comments

Comments (0)