-
Notifications
You must be signed in to change notification settings - Fork 89
Expand file tree
/
Copy pathgcn_on_cora.py
More file actions
86 lines (74 loc) · 2.69 KB
/
gcn_on_cora.py
File metadata and controls
86 lines (74 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
from copy import deepcopy
import torch
import torch.optim as optim
import torch.nn.functional as F
from dhg import Graph, Hypergraph
from dhg.data import Cora, Pubmed, Citeseer
from dhg.models import GCN, GIN, HyperGCN, GraphSAGE
from dhg.random import set_seed
from dhg.metrics import GraphVertexClassificationEvaluator as Evaluator
def train(net, X, A, lbls, train_idx, optimizer, epoch):
    """Perform a single optimisation step of ``net`` on the training vertices.

    The full forward pass ``net(X, A)`` is computed, then only the rows
    selected by ``train_idx`` contribute to the cross-entropy loss.

    Args:
        net: model called as ``net(X, A)``; switched to training mode here.
        X: input vertex features passed straight to ``net``.
        A: graph structure object passed straight to ``net``.
        lbls: target labels, indexed with ``train_idx``.
        train_idx: index or mask selecting the training vertices.
        optimizer: optimiser whose ``zero_grad``/``step`` are called once.
        epoch: epoch number, used only in the progress line.

    Returns:
        The training loss as a Python float (value before ``optimizer.step``).
    """
    net.train()
    start = time.time()
    optimizer.zero_grad()
    # Forward on every vertex, then keep only the training rows for the loss.
    logits = net(X, A)[train_idx]
    targets = lbls[train_idx]
    loss = F.cross_entropy(logits, targets)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    elapsed = time.time() - start
    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")
    return loss_value
@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False, evaluator=None):
    """Evaluate ``net`` on the vertices selected by ``idx`` without tracking gradients.

    Args:
        net: model called as ``net(X, A)``; switched to eval mode here.
        X: input vertex features passed straight to ``net``.
        A: graph structure object passed straight to ``net``.
        lbls: target labels, indexed with ``idx``.
        idx: index or mask selecting the vertices to score.
        test: when True call ``evaluator.test``; otherwise ``evaluator.validate``.
        evaluator: object exposing ``validate(lbls, outs)`` and
            ``test(lbls, outs)``. When omitted, the module-level ``evaluator``
            (created in this file's ``__main__`` block) is used, which keeps
            the original calling convention working.

    Returns:
        Whatever ``evaluator.validate`` / ``evaluator.test`` returns.

    Raises:
        NameError: if no evaluator is passed and no module-level one exists
            (e.g. when this function is imported from another module).
    """
    if evaluator is None:
        # Previously an implicit global lookup; keep that fallback but fail
        # with an actionable message instead of a bare NameError.
        evaluator = globals().get("evaluator")
        if evaluator is None:
            raise NameError(
                "infer() needs an evaluator: pass evaluator=... or define a "
                "module-level 'evaluator'"
            )
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    if test:
        return evaluator.test(lbls, outs)
    return evaluator.validate(lbls, outs)
if __name__ == "__main__":
    # Script entry point: train on the train split, keep the weights with the
    # best validation score, then report test metrics for those weights.
    set_seed(2022)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Kept at module level on purpose: infer() resolves `evaluator` as a
    # module global when it is not passed one explicitly.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Cora()
    # data = Pubmed()
    # data = Citeseer()
    X, lbl = data["features"], data["labels"]
    G = Graph(data["num_vertices"], data["edge_list"])
    # G = Hypergraph(data["num_vertices"], data["edge_list"])
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]
    net = GCN(data["dim_features"], 16, data["num_classes"])
    # net = GraphSAGE(data["dim_features"], 16, data["num_classes"])
    # net = HyperGCN(data["dim_features"], 16, data["num_classes"])
    # net = GIN(data["dim_features"], 16, data["num_classes"], num_layers=5, train_eps=True)
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(200):
        # train
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # Validate every epoch. infer() is already decorated with
        # @torch.no_grad(), so no extra no_grad context is needed here.
        val_res = infer(net, X, G, lbl, val_mask)
        if val_res > best_val:
            print(f"update best: {val_res:.5f}")
            best_epoch = epoch
            best_val = val_res
            # Snapshot a deep copy so later optimiser steps cannot overwrite
            # the saved "best" tensors.
            best_state = deepcopy(net.state_dict())
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")
    # test
    print("test...")
    if best_state is not None:
        net.load_state_dict(best_state)
    else:
        # Validation never exceeded the initial best_val of 0, so no snapshot
        # exists; load_state_dict(None) would crash, so test the final weights.
        print("no validation improvement recorded; testing final weights")
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)