Skip to content

Commit da491ee

Browse files
committed
Implement gat encode
1 parent 0119455 commit da491ee

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

models/gat_encoder.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
from models.graph_convolution import GraphConvolution
4+
from models.graph_attention_layer import GraphAttentionLayer
5+
import torch
6+
7+
8+
class GATEncoder(nn.Module):
9+
def __init__(self, num_features, hidden_size, dropout=0):
10+
super(GATEncoder, self).__init__()
11+
12+
self.gc1 = GraphConvolution(num_features, hidden_size)
13+
self.gc2 = GraphConvolution(hidden_size, hidden_size)
14+
15+
self.attention1 = GraphAttentionLayer(num_features, hidden_size)
16+
self.attention2 = GraphAttentionLayer(hidden_size, hidden_size)
17+
18+
self.dropout = dropout
19+
20+
def forward(self, x, adj):
21+
x = self.attention1(x, adj)
22+
x = self.attention1(x, adj, concat=False)
23+
24+
return x.mean(dim=0)
25+

0 commit comments

Comments
 (0)