Skip to content

Commit 0119455

Browse files
committed
Implement graph attention layer
1 parent fff1606 commit 0119455

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

models/gat_encoder.py

Whitespace-only changes.

models/graph_attention_layer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class GraphAttentionLayer(nn.Module):
7+
"""
8+
Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
9+
"""
10+
11+
def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
12+
super(GraphAttentionLayer, self).__init__()
13+
self.dropout = dropout
14+
self.in_features = in_features
15+
self.out_features = out_features
16+
self.alpha = alpha
17+
self.concat = concat
18+
19+
self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
20+
nn.init.xavier_uniform_(self.W.data, gain=1.414)
21+
self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
22+
nn.init.xavier_uniform_(self.a.data, gain=1.414)
23+
24+
self.leakyrelu = nn.LeakyReLU(self.alpha)
25+
26+
def forward(self, input, adj):
27+
h = torch.mm(input, self.W)
28+
N = h.size(0)
29+
30+
a_input = torch.cat((h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)), dim=1).view(N, -1,
31+
2 * self.out_features)
32+
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
33+
34+
zero_vec = -9e15*torch.ones_like(e)
35+
attention = torch.where(adj > 0, e, zero_vec)
36+
attention = F.softmax(attention, dim=1)
37+
attention = F.dropout(attention, self.dropout, training=self.training)
38+
h_prime = torch.matmul(attention, h)
39+
40+
if self.concat:
41+
return F.elu(h_prime)
42+
else:
43+
return h_prime

0 commit comments

Comments
 (0)