forked from pyg-team/pytorch_geometric
-
Notifications
You must be signed in to change notification settings - Fork 0
/
reddit.py
116 lines (90 loc) · 4.01 KB
/
reddit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import copy
import os.path as osp
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
# Prefer CUDA when it is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The dataset root is '../data/Reddit', resolved relative to this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)

# Move only the node features ('x') and labels ('y') of the first graph to
# `device` up front, for faster access during sampling.
data = dataset[0].to(device, 'x', 'y')

# Loader options shared by the training and the evaluation loader.
kwargs = dict(batch_size=1024, num_workers=6, persistent_workers=True)

# Training loader: seeded from `data.train_mask`, shuffled, with
# num_neighbors=[25, 10].
train_loader = NeighborLoader(
    data,
    input_nodes=data.train_mask,
    num_neighbors=[25, 10],
    shuffle=True,
    **kwargs,
)

# Evaluation loader: works on a shallow copy of `data`, unshuffled, with
# num_neighbors=[-1].
subgraph_loader = NeighborLoader(
    copy.copy(data),
    input_nodes=None,
    num_neighbors=[-1],
    shuffle=False,
    **kwargs,
)

# The evaluation copy does not need to carry features or labels around;
# `SAGE.inference` looks features up through `n_id` instead.
del subgraph_loader.data.x
del subgraph_loader.data.y

# Attach the node count and a global node index to the evaluation copy.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
class SAGE(torch.nn.Module):
    """Two-layer network built from ``SAGEConv`` layers.

    Args:
        in_channels: Input size passed to the first ``SAGEConv``.
        hidden_channels: Output size of the first layer and input size of
            the second one.
        out_channels: Output size of the second (final) ``SAGEConv``.
        dropout: Dropout probability applied after every non-final layer
            while the module is in training mode. Defaults to ``0.5``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        """Apply all layers to ``x`` using ``edge_index``.

        Every layer except the last is followed by an in-place ReLU and by
        dropout (active only when ``self.training`` is true). The output of
        the last layer is returned unchanged.
        """
        last_layer = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last_layer:
                x = x.relu_()
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        """Compute the output of the network layer by layer.

        Args:
            x_all: Tensor that is indexed with each batch's ``n_id`` to
                obtain the input rows of the first layer.
            subgraph_loader: Iterable of batches exposing ``n_id``,
                ``edge_index`` and ``batch_size``; it is iterated once per
                layer and ``len(subgraph_loader.dataset)`` sizes the
                progress bar.

        Returns:
            The concatenation (on the CPU) of the first ``batch_size``
            output rows of every batch, taken from the final layer.
        """
        # Run on the device the layers live on, so that the class does not
        # depend on a module-level `device` variable being defined.
        model_device = next(self.parameters()).device
        last_layer = len(self.convs) - 1
        total = len(subgraph_loader.dataset) * len(self.convs)

        # The context manager closes the bar even when a batch raises.
        with tqdm(total=total) as pbar:
            pbar.set_description('Evaluating')

            # Compute representations of nodes layer by layer, using *all*
            # available edges. This leads to faster computation in contrast
            # to immediately computing the final representations of each
            # batch:
            for i, conv in enumerate(self.convs):
                xs = []
                for batch in subgraph_loader:
                    x = x_all[batch.n_id.to(x_all.device)].to(model_device)
                    x = conv(x, batch.edge_index.to(model_device))
                    if i < last_layer:
                        x = x.relu_()
                    # Keep only the first `batch_size` rows, stored on the
                    # CPU; no dropout is applied during inference.
                    xs.append(x[:batch.batch_size].cpu())
                    pbar.update(batch.batch_size)
                # The collected rows become the input of the next layer.
                x_all = torch.cat(xs, dim=0)
        return x_all
# Build the network with hidden_channels=256, sized from the dataset's
# feature and class counts, then move it to `device`.
model = SAGE(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes,
)
model = model.to(device)

# Adam over all model parameters with a learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train(epoch):
    """Run one pass over ``train_loader`` and update ``model``.

    Args:
        epoch: Epoch number, used only to label the progress bar.

    Returns:
        A ``(loss, accuracy)`` pair, both averaged over the rows that
        contributed to the loss (the first ``batch_size`` rows of each
        batch).
    """
    model.train()

    progress = tqdm(total=len(train_loader.dataset))
    progress.set_description(f'Epoch {epoch:02d}')

    loss_sum = 0
    num_correct = 0
    num_seen = 0
    for batch in train_loader:
        num_rows = batch.batch_size

        optimizer.zero_grad()
        # Only the first `batch_size` rows of the output and of the labels
        # take part in the loss.
        logits = model(batch.x, batch.edge_index.to(device))[:num_rows]
        target = batch.y[:num_rows]
        loss = F.cross_entropy(logits, target)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its row count so that the final average
        # is taken per row rather than per batch.
        loss_sum += float(loss) * num_rows
        num_correct += int((logits.argmax(dim=-1) == target).sum())
        num_seen += num_rows
        progress.update(num_rows)
    progress.close()

    return loss_sum / num_seen, num_correct / num_seen
@torch.no_grad()
def test():
    """Evaluate ``model`` and return accuracies for the three data splits.

    Returns:
        A list ``[train_acc, val_acc, test_acc]``: for each of
        ``data.train_mask``, ``data.val_mask`` and ``data.test_mask``, the
        fraction of masked nodes whose predicted class equals the label.
    """
    model.eval()

    logits = model.inference(data.x, subgraph_loader)
    pred = logits.argmax(dim=-1)
    # Compare on whichever device the predictions ended up on.
    target = data.y.to(pred.device)

    return [
        int((pred[mask] == target[mask]).sum()) / int(mask.sum())
        for mask in (data.train_mask, data.val_mask, data.test_mask)
    ]
# The loaders use worker processes (`num_workers=6`). On platforms that start
# workers by spawning a fresh interpreter, every worker re-imports this
# module; the guard keeps the training loop from being re-entered there.
# Running the file as a script behaves exactly as before.
if __name__ == '__main__':
    # Train for ten epochs, reporting the running training metrics and then
    # the accuracies from a full evaluation after every epoch.
    for epoch in range(1, 11):
        loss, acc = train(epoch)
        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
        train_acc, val_acc, test_acc = test()
        print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
              f'Test: {test_acc:.4f}')