forked from pyg-team/pytorch_geometric
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ogbn_proteins_deepgcn.py
144 lines (104 loc) · 4.47 KB
/
ogbn_proteins_deepgcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU
from tqdm import tqdm
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import scatter
# Load ogbn-proteins (stored under ../data) together with its predefined
# train/valid/test node-index split.
dataset = PygNodePropPredDataset('ogbn-proteins', root='../data')
splitted_idx = dataset.get_idx_split()

# Work on the first graph of the dataset. Its species attribute is cleared
# (nothing in this script reads it) and the labels are cast to float, which
# the BCE-with-logits loss used for training expects as targets.
data = dataset[0]
data.node_species = None
data.y = data.y.float()

# Node inputs are derived from edge attributes: each edge adds its attribute
# vector to the node listed in the second row of `edge_index`.
row, col = data.edge_index[0], data.edge_index[1]
data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum')

# Turn every split's node indices into a boolean mask stored on `data` as
# `train_mask`, `valid_mask` and `test_mask`.
for split in ('train', 'valid', 'test'):
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

# Training iterates over 40 node partitions with shuffle=True; evaluation
# iterates over 5 partitions. Both loaders are created with num_workers=5.
train_loader = RandomNodeLoader(
    data,
    num_parts=40,
    shuffle=True,
    num_workers=5,
)
test_loader = RandomNodeLoader(
    data,
    num_parts=5,
    num_workers=5,
)
class DeeperGCN(torch.nn.Module):
    """Stack of GENConv-based ``DeepGCNLayer`` blocks for per-node prediction.

    Node and edge inputs are projected to ``hidden_channels`` by two linear
    encoders, passed through ``num_layers`` ``DeepGCNLayer`` blocks built
    with ``block='res+'``, and mapped to the output size by a final linear
    layer. No activation is applied to the final output.

    Args:
        hidden_channels: Width used by the encoders and every layer.
        num_layers: Number of ``DeepGCNLayer`` blocks; must be at least 1
            because ``forward`` always uses the first block.
        in_channels: Size of the node input features. Defaults to
            ``data.x.size(-1)`` of the module-level ``data``.
        edge_channels: Size of the edge input features. Defaults to
            ``data.edge_attr.size(-1)`` of the module-level ``data``.
        out_channels: Size of the per-node output. Defaults to
            ``data.y.size(-1)`` of the module-level ``data``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, hidden_channels, num_layers, in_channels=None,
                 edge_channels=None, out_channels=None):
        super().__init__()

        if num_layers < 1:
            raise ValueError(
                f'num_layers must be at least 1 (got {num_layers})')

        # Only sizes the caller did not supply are read from the module-level
        # graph, so the model can also be built for other inputs.
        if in_channels is None:
            in_channels = data.x.size(-1)
        if edge_channels is None:
            edge_channels = data.edge_attr.size(-1)
        if out_channels is None:
            out_channels = data.y.size(-1)

        self.node_encoder = Linear(in_channels, hidden_channels)
        self.edge_encoder = Linear(edge_channels, hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
                           t=1.0, learn_t=True, num_layers=2, norm='layer')
            norm = LayerNorm(hidden_channels, elementwise_affine=True)
            act = ReLU(inplace=True)

            # `ckpt_grad` is enabled for every layer whose 1-based index is
            # not a multiple of 3 and disabled for the remaining ones.
            layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1,
                                 ckpt_grad=bool(i % 3))
            self.layers.append(layer)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Compute per-node outputs for one (sub)graph.

        Args:
            x: Node input features.
            edge_index: Graph connectivity passed on to every layer.
            edge_attr: Edge input features.

        Returns:
            The output of the final linear layer for every node.
        """
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        # Only the convolution of the first block runs here; its norm and
        # activation are applied after all remaining blocks have run.
        x = self.layers[0].conv(x, edge_index, edge_attr)

        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)

        x = self.layers[0].act(self.layers[0].norm(x))
        x = F.dropout(x, p=0.1, training=self.training)

        return self.lin(x)
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 28 layers of width 64, optimised with Adam at a learning rate of 0.01.
model = DeeperGCN(hidden_channels=64, num_layers=28)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Loss applied to the raw model outputs, and the evaluator that test() queries
# for the 'rocauc' score.
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator('ogbn-proteins')
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. The loss of each batch is computed on
    the nodes selected by that batch's ``train_mask`` only.

    Args:
        epoch: Epoch number; only used to label the progress bar.

    Returns:
        The loss averaged over all training nodes seen in this pass, with
        every batch weighted by its number of training nodes.

    Raises:
        ZeroDivisionError: If no batch contained a training node.
    """
    model.train()

    total_loss = total_examples = 0

    # The context manager closes the progress bar even when a step raises.
    with tqdm(train_loader, desc=f'Training epoch: {epoch:04d}') as pbar:
        for batch in pbar:
            batch = batch.to(device)

            # Counted once per batch and reused for the weighting below.
            num_train = int(batch.train_mask.sum())
            if num_train == 0:
                # The criterion's default mean reduction over zero elements
                # yields NaN, which would turn the running total into NaN;
                # such a batch carries no training signal, so skip it.
                continue

            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
            loss.backward()
            optimizer.step()

            total_loss += float(loss) * num_train
            total_examples += num_train

    return total_loss / total_examples
@torch.no_grad()
def test():
    """Evaluate the model on every batch of ``test_loader``.

    Uses the module-level ``model``, ``test_loader``, ``device`` and
    ``evaluator``, and reads the module-level ``epoch`` to label the
    progress bar. Targets and outputs are gathered per split through the
    ``train_mask``/``valid_mask``/``test_mask`` of each batch and scored
    once per split after the whole loader has been consumed.

    Returns:
        A ``(train, valid, test)`` tuple holding the evaluator's
        ``'rocauc'`` value for each split.
    """
    model.eval()

    splits = ('train', 'valid', 'test')
    y_true = {name: [] for name in splits}
    y_pred = {name: [] for name in splits}

    pbar = tqdm(total=len(test_loader))
    pbar.set_description(f'Evaluating epoch: {epoch:04d}')

    for batch in test_loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.edge_attr)

        # Keep the per-split slices on the CPU until they are concatenated.
        for name in splits:
            mask = batch[f'{name}_mask']
            y_true[name].append(batch.y[mask].cpu())
            y_pred[name].append(out[mask].cpu())

        pbar.update(1)

    pbar.close()

    # Score the splits in train, valid, test order.
    train_rocauc, valid_rocauc, test_rocauc = (
        evaluator.eval({
            'y_true': torch.cat(y_true[name], dim=0),
            'y_pred': torch.cat(y_pred[name], dim=0),
        })['rocauc']
        for name in splits
    )

    return train_rocauc, valid_rocauc, test_rocauc
# Alternate one training pass with one full evaluation for 1000 epochs. The
# loop variable has to be the module-level name `epoch`, because test() reads
# it for its progress-bar label.
for epoch in range(1, 1001):
    loss = train(epoch)
    train_rocauc, valid_rocauc, test_rocauc = test()

    report = ', '.join([
        f'Loss: {loss:.4f}',
        f'Train: {train_rocauc:.4f}',
        f'Val: {valid_rocauc:.4f}',
        f'Test: {test_rocauc:.4f}',
    ])
    print(report)