-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
label_prop.py
32 lines (25 loc) · 931 Bytes
/
label_prop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os.path as osp
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
import torch_geometric.transforms as T
from torch_geometric.nn import LabelPropagation
def main():
    """Run label propagation on ogbn-arxiv and print validation/test accuracy.

    The ogbn-arxiv dataset is loaded from ``../data/OGB`` relative to this
    file. Labels of the training nodes are propagated over the graph with
    ``LabelPropagation``. Each node is then predicted as the class with the
    highest propagated score, and accuracy on the validation and test splits
    is computed with the OGB ``Evaluator``.
    """
    root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')
    # ToUndirected adds reverse edges. ToSparseTensor provides the sparse
    # adjacency ``data.adj_t`` that the model consumes below.
    dataset = PygNodePropPredDataset(
        'ogbn-arxiv', root, transform=T.Compose([
            T.ToUndirected(),
            T.ToSparseTensor(),
        ]))
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(name='ogbn-arxiv')
    data = dataset[0]

    # Only labels at the training indices are used as propagation seeds
    # (``mask``); validation/test labels are used solely for scoring below.
    model = LabelPropagation(num_layers=3, alpha=0.9)
    out = model(data.y, data.adj_t, mask=split_idx['train'])
    # keepdim=True keeps predictions 2-D (one column), presumably matching the
    # layout of data.y and the format the OGB Evaluator expects -- confirm.
    y_pred = out.argmax(dim=-1, keepdim=True)

    def accuracy(split):
        """Return the OGB accuracy of ``y_pred`` on the nodes of ``split``."""
        idx = split_idx[split]
        return evaluator.eval({
            'y_true': data.y[idx],
            'y_pred': y_pred[idx],
        })['acc']

    val_acc = accuracy('valid')
    test_acc = accuracy('test')
    print(f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')


# Guard the entry point so importing this module does not download the
# dataset or run the evaluation as a side effect.
if __name__ == '__main__':
    main()