-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
kge_fb15k_237.py
89 lines (73 loc) · 2.52 KB
/
kge_fb15k_237.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os.path as osp
import torch
import torch.optim as optim
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE
# Map the lowercase CLI model name to its torch_geometric KGE model class.
model_map = {
    'transe': TransE,
    'complex': ComplEx,
    'distmult': DistMult,
    'rotate': RotatE,
}

parser = argparse.ArgumentParser()
# `type=str.lower` normalizes input so e.g. `--model TransE` still matches
# the lowercase keys of `model_map`.
parser.add_argument('--model', choices=model_map.keys(), type=str.lower,
                    required=True)
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the on-disk directory is named 'FB15k' although the dataset
# being loaded is FB15k-237 — presumably kept for compatibility with already
# downloaded data; confirm before renaming.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FB15k')
train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

# Per-model extra constructor kwargs; only RotatE gets an override here.
model_arg_map = {'rotate': {'margin': 9.0}}
model = model_map[args.model](
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
    **model_arg_map.get(args.model, {}),
).to(device)

# Mini-batch loader over the (head, relation, tail) training triples.
loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)

# Per-model optimizer configuration (presumably tuned per model family —
# values are not derived from anything visible in this file).
# NOTE(review): all four optimizers are constructed eagerly even though only
# one is used; harmless, but they could be built lazily from (cls, kwargs).
optimizer_map = {
    'transe': optim.Adam(model.parameters(), lr=0.01),
    'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),
    'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
    'rotate': optim.Adam(model.parameters(), lr=1e-3),
}
optimizer = optimizer_map[args.model]
def train():
    """Run one epoch over `loader` and return the mean per-triple loss.

    Performs a standard optimize step per mini-batch and weights each
    batch's loss by its number of head indices when averaging.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for heads, rels, tails in loader:
        optimizer.zero_grad()
        batch_loss = model.loss(heads, rels, tails)
        batch_loss.backward()
        optimizer.step()
        count = heads.numel()
        loss_sum += float(batch_loss) * count
        seen += count
    return loss_sum / seen
@torch.no_grad()
def test(data):
    """Evaluate `model` on `data`'s triples.

    Returns the (mean rank, MRR, hits@k) triple produced by
    ``model.test`` with k=10, computed without gradient tracking.
    """
    model.eval()
    heads, tails = data.edge_index
    return model.test(
        head_index=heads,
        rel_type=data.edge_type,
        tail_index=tails,
        batch_size=20000,
        k=10,
    )
# Train for 500 epochs, reporting validation metrics every 25 epochs.
for epoch in range(1, 501):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 25 == 0:
        rank, mrr, hits = test(val_data)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
              f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

# Final evaluation on the held-out test split.
rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
      f'Test Hits@10: {hits_at_10:.4f}')