File tree Expand file tree Collapse file tree 1 file changed +25
-0
lines changed Expand file tree Collapse file tree 1 file changed +25
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch .nn as nn
2
+ import torch .nn .functional as F
3
+ from models .graph_convolution import GraphConvolution
4
+ from models .graph_attention_layer import GraphAttentionLayer
5
+ import torch
6
+
7
+
8
+ class GATEncoder (nn .Module ):
9
+ def __init__ (self , num_features , hidden_size , dropout = 0 ):
10
+ super (GATEncoder , self ).__init__ ()
11
+
12
+ self .gc1 = GraphConvolution (num_features , hidden_size )
13
+ self .gc2 = GraphConvolution (hidden_size , hidden_size )
14
+
15
+ self .attention1 = GraphAttentionLayer (num_features , hidden_size )
16
+ self .attention2 = GraphAttentionLayer (hidden_size , hidden_size )
17
+
18
+ self .dropout = dropout
19
+
20
+ def forward (self , x , adj ):
21
+ x = self .attention1 (x , adj )
22
+ x = self .attention1 (x , adj , concat = False )
23
+
24
+ return x .mean (dim = 0 )
25
+
You can’t perform that action at this time.
0 commit comments