-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
rgat.py
64 lines (50 loc) · 2.02 KB
/
rgat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os.path as osp
import time
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Entities
from torch_geometric.nn import RGATConv
# Load the 'AIFB' graph of the Entities dataset, stored in ../data/Entities
# relative to the directory containing this script.
_script_dir = osp.dirname(osp.realpath(__file__))
path = osp.join(_script_dir, '..', 'data', 'Entities')
dataset = Entities(path, 'AIFB')
data = dataset[0]

# Give every node a fixed random 16-dimensional feature vector; this width
# must match the model's in_channels.
# NOTE(review): presumably the graph ships without node features — confirm.
data.x = torch.randn(data.num_nodes, 16)
class RGAT(torch.nn.Module):
    """Two relational graph-attention layers followed by a linear classifier.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width produced by both RGATConv layers.
        out_channels: Number of output classes.
        num_relations: Number of distinct edge types passed to RGATConv.

    ``forward`` returns per-node log-probabilities (log-softmax over the
    last dimension), suitable for ``F.nll_loss``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_relations):
        super().__init__()
        # Submodules are registered in assignment order; keep conv1, conv2,
        # lin in this order so parameter ordering and initialisation order
        # stay the same.
        self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_type):
        """Apply both RGAT layers (each followed by ReLU), then the head."""
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index, edge_type))
        return self.lin(h).log_softmax(dim=-1)
# Prefer the GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
data = data.to(device)

# 16 input channels to match the random node features, 16 hidden channels.
model = RGAT(
    in_channels=16,
    hidden_channels=16,
    out_channels=dataset.num_classes,
    num_relations=dataset.num_relations,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    """Run one optimisation step, computing the loss on the training nodes.

    The model is evaluated on the whole graph and the loss is taken only over
    ``data.train_idx`` against ``data.train_y``.

    Returns:
        The negative log-likelihood training loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    log_probs = model(data.x, data.edge_index, data.edge_type)
    train_loss = F.nll_loss(log_probs[data.train_idx], data.train_y)
    train_loss.backward()
    optimizer.step()
    return train_loss.item()
@torch.no_grad()
def test():
    """Return (train accuracy, test accuracy) as Python floats.

    Predictions are the arg-max class of the model's output for every node;
    accuracy is the fraction of correct predictions within each split.
    """
    model.eval()
    logits = model(data.x, data.edge_index, data.edge_type)
    pred = logits.argmax(dim=-1)
    splits = ((data.train_idx, data.train_y), (data.test_idx, data.test_y))
    accs = []
    for idx, target in splits:
        accs.append(float((pred[idx] == target).float().mean()))
    return tuple(accs)
# Train for 50 epochs. Each epoch runs one training step and one evaluation,
# logs the results, and records the elapsed wall-clock time for the epoch
# (train + evaluation + logging).
times = []
for epoch in range(1, 51):
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the appropriate clock for measuring elapsed durations.
    start = time.perf_counter()
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.perf_counter() - start)
# torch.median returns the lower of the two middle values for an even number
# of elements; quantile(0.5) interpolates and yields the conventional median.
print(f"Median time per epoch: {torch.tensor(times).quantile(0.5):.4f}s")