-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
edge_pool.py
55 lines (49 loc) · 1.83 KB
/
edge_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (
EdgePooling,
GraphConv,
JumpingKnowledge,
global_mean_pool,
)
class EdgePool(torch.nn.Module):
    """Graph classifier built from GraphConv layers and EdgePooling.

    The network consists of ``num_layers`` mean-aggregating ``GraphConv``
    layers.  After every convolution the node features are mean-pooled per
    graph; the per-layer readouts are combined by ``JumpingKnowledge`` in
    ``'cat'`` mode and passed through a two-layer MLP head.  An
    ``EdgePooling`` step follows every other convolution of ``self.convs``
    (even index), except the final one.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of graph convolutions (``conv1`` plus
            ``num_layers - 1`` entries in ``self.convs``).
        hidden: Width of every hidden representation.

    NOTE(review): ``num_layers // 2`` pooling layers are created, but for an
    even ``num_layers`` the forward pass uses one fewer than that, so the
    last pooling layer is registered without ever being called.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        # Creation order is kept stable: it fixes both the order in which
        # parameters are initialised and the layout of the state dict.
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')

        extra_convs = []
        for _ in range(num_layers - 1):
            extra_convs.append(GraphConv(hidden, hidden, aggr='mean'))
        self.convs = torch.nn.ModuleList(extra_convs)

        poolers = []
        for _ in range(num_layers // 2):
            poolers.append(EdgePooling(hidden))
        self.pools = torch.nn.ModuleList(poolers)

        self.jump = JumpingKnowledge(mode='cat')
        # One ``hidden``-wide readout per convolution is fed to the head.
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise the convolutions, pooling layers and linear head."""
        self.conv1.reset_parameters()
        for group in (self.convs, self.pools):
            for layer in group:
                layer.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return per-graph log-probabilities for the batch in ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``.
        """
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        x = F.relu(self.conv1(x, edge_index))
        readouts = [global_mean_pool(x, batch)]

        final_idx = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            readouts.append(global_mean_pool(x, batch))

            # Pool only after even-indexed convolutions, never after the
            # final one.
            if idx % 2 != 0 or idx >= final_idx:
                continue
            pool = self.pools[idx // 2]
            x, edge_index, batch, _ = pool(x, edge_index, batch=batch)

        out = self.jump(readouts)
        out = F.relu(self.lin1(out))
        out = F.dropout(out, p=0.5, training=self.training)
        out = self.lin2(out)
        return F.log_softmax(out, dim=-1)

    def __repr__(self):
        return type(self).__name__