-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
pna.py
115 lines (90 loc) · 3.79 KB
/
pna.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os.path as osp
import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch_geometric
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.utils import degree
# The ZINC data lives in ../data/ZINC relative to this script; load the
# 'train', 'val' and 'test' splits of the subset variant, in that order.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')
train_dataset, val_dataset, test_dataset = (
    ZINC(path, subset=True, split=split) for split in ('train', 'val', 'test')
)

# All loaders use batches of 128 graphs; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset, batch_size=128) for dataset in (val_dataset, test_dataset)
)
def _in_degree_histogram(dataset):
    """Return the in-degree histogram of every node in ``dataset``.

    Entry ``k`` of the returned ``torch.long`` tensor is the number of nodes
    whose in-degree (count of occurrences in ``edge_index[1]``) equals ``k``.
    Its length is ``max in-degree + 1``, or 0 when there are no nodes at all.
    """
    per_graph = [
        degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        for data in dataset
    ]
    if not per_graph:
        return torch.zeros(0, dtype=torch.long)
    # A bincount over the concatenated degrees equals the sum of per-graph
    # bincounts, and its length is automatically max(in-degree) + 1, so a
    # single pass replaces the separate "find max" and "accumulate" passes.
    # Graphs with zero nodes contribute empty tensors and are handled too.
    return torch.bincount(torch.cat(per_graph))


# In-degree histogram of the training graphs, passed to every PNAConv as
# ``deg``; ``max_degree`` is kept for compatibility (-1 if there are no nodes).
deg = _in_degree_histogram(train_dataset)
max_degree = deg.numel() - 1
class Net(torch.nn.Module):
    """PNA graph-regression network for the ZINC example.

    Integer node and edge features are embedded. They then pass through four
    ``PNAConv`` -> ``BatchNorm`` -> ReLU layers. Node embeddings are summed
    per graph with ``global_add_pool`` and mapped to one scalar by a small
    MLP.

    The convolutions read the module-level in-degree histogram ``deg``, so it
    must exist before the model is constructed.
    """

    def __init__(self):
        super().__init__()

        # 21 node-feature ids -> 75 dims; 4 edge-feature ids -> 50 dims.
        # NOTE(review): presumably the ZINC atom / bond vocabularies; confirm.
        self.node_emb = Embedding(21, 75)
        self.edge_emb = Embedding(4, 50)

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(4):
            conv = PNAConv(in_channels=75, out_channels=75,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=50, towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(75))

        self.mlp = Sequential(Linear(75, 50), ReLU(), Linear(50, 25), ReLU(),
                              Linear(25, 1))

    def forward(self, x, edge_index, edge_attr, batch):
        """Predict a single scalar for every graph in the batch.

        Args:
            x: Integer node features. Assumed shape is ``[num_nodes, 1]``
                (TODO confirm); the trailing axis is dropped before embedding.
            edge_index: Graph connectivity, passed to every ``PNAConv``.
            edge_attr: Integer edge features, embedded by ``edge_emb``.
            batch: Node-to-graph assignment used by ``global_add_pool``.

        Returns:
            Tensor with one row per graph and a single output column.
        """
        # Squeeze only the last axis. A bare ``squeeze()`` also removes the
        # node axis when the batch holds a single node. The embedding would
        # then be [75] instead of [1, 75] and the convolutions would fail.
        x = self.node_emb(x.squeeze(-1))
        edge_attr = self.edge_emb(edge_attr)
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))
        x = global_add_pool(x, batch)
        return self.mlp(x)
# Pick the accelerator: CUDA first, then Intel XPU, else fall back to CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'xpu' if torch_geometric.is_xpu_available()
    else 'cpu'
)

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate after 20 epochs without validation improvement,
# never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    ``epoch`` is accepted for the caller's convenience but is not used here.

    Returns the per-graph mean absolute error accumulated during the pass.
    Each batch's mean loss is weighted by its number of graphs, then divided
    by the size of the training dataset.
    """
    model.train()
    loss_sum = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        prediction = model(batch_data.x, batch_data.edge_index,
                           batch_data.edge_attr, batch_data.batch)
        batch_loss = (prediction.squeeze() - batch_data.y).abs().mean()
        batch_loss.backward()
        # Undo the batch mean so batches of different size weigh correctly.
        loss_sum += batch_loss.item() * batch_data.num_graphs
        optimizer.step()
    return loss_sum / len(train_loader.dataset)
@torch.no_grad()
def test(loader):
    """Return the mean absolute error of ``model`` over ``loader``'s dataset.

    The model is put in eval mode and gradients are disabled. Absolute errors
    are summed over all graphs and divided by the dataset size.
    """
    model.eval()
    abs_error_sum = 0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        prediction = model(batch_data.x, batch_data.edge_index,
                           batch_data.edge_attr, batch_data.batch)
        residual = prediction.squeeze() - batch_data.y
        abs_error_sum += residual.abs().sum().item()
    return abs_error_sum / len(loader.dataset)
# 300 epochs; the plateau scheduler is driven by the validation MAE.
for epoch in range(1, 301):
    loss = train(epoch)
    val_mae, test_mae = test(val_loader), test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')