Skip to content

Commit

Permalink
move the dgl functions out of tkg_utils to tkg_utils_dgl and update i…
Browse files Browse the repository at this point in the history
…mports in cen and regcn
  • Loading branch information
JuliaGast committed May 31, 2024
1 parent c5042aa commit c96805f
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 9 deletions.
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-icews/cen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_cen, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset

Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-icews/regcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_regcn, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-polecat/cen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_cen, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset

Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-polecat/regcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_regcn, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-smallpedia/cen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_cen, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset

Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-smallpedia/regcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_regcn, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-wikidata/regcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_regcn, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-yago/cen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_cen, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset

Expand Down
3 changes: 2 additions & 1 deletion examples/linkproppred/tkgl-yago/regcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
sys.path.append(tgb_modules_path)
from tgb_modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts
from tgb_modules.tkg_utils import get_args_regcn, reformat_ts
from tgb_modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
Expand Down
48 changes: 48 additions & 0 deletions tgb_modules/tkg_utils_dgl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

import dgl
import torch
import numpy as np


def build_sub_graph(num_nodes, num_rels, triples, use_cuda, gpu, mode='dyn'):
"""
https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py
:param node_id: node id in the large graph
:param num_rels: number of relation
:param src: relabeled src id
:param rel: original rel id
:param dst: relabeled dst id
:param use_cuda:
:return:
"""
def comp_deg_norm(g):
in_deg = g.in_degrees(range(g.number_of_nodes())).float()
in_deg[torch.nonzero(in_deg == 0).view(-1)] = 1
norm = 1.0 / in_deg
return norm

src, rel, dst = triples.transpose()
if mode =='static':
src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
rel = np.concatenate((rel, rel + num_rels))
g = dgl.DGLGraph()
g.add_nodes(num_nodes)
#g.ndata['original_id'] = np.unique(np.concatenate((np.unique(triples[:,0]), np.unique(triples[:,2]))))
g.add_edges(src, dst)
norm = comp_deg_norm(g)
#node_id =torch.arange(0, g.num_nodes(), dtype=torch.long).view(-1, 1) #updated to deal with the fact that ot only the first k nodes of our graph have static infos
node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
g.ndata.update({'id': node_id, 'norm': norm.view(-1, 1)})
g.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
g.edata['type'] = torch.LongTensor(rel)


uniq_r, r_len, r_to_e = r2e(triples, num_rels)
g.uniq_r = uniq_r
g.r_to_e = r_to_e
g.r_len = r_len

if use_cuda:
g = g.to(gpu)
g.r_to_e = torch.from_numpy(np.array(r_to_e))
return g

0 comments on commit c96805f

Please sign in to comment.