diff --git a/examples/linkproppred/tkgl-icews/cen.py b/examples/linkproppred/tkgl-icews/cen.py index 97d0055..6fb5ace 100644 --- a/examples/linkproppred/tkgl-icews/cen.py +++ b/examples/linkproppred/tkgl-icews/cen.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNCEN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_cen, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset diff --git a/examples/linkproppred/tkgl-icews/regcn.py b/examples/linkproppred/tkgl-icews/regcn.py index b2e49e2..2785ade 100644 --- a/examples/linkproppred/tkgl-icews/regcn.py +++ b/examples/linkproppred/tkgl-icews/regcn.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNREGCN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_regcn, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset import json diff --git a/examples/linkproppred/tkgl-polecat/cen.py b/examples/linkproppred/tkgl-polecat/cen.py index b604927..4e6a219 100644 --- a/examples/linkproppred/tkgl-polecat/cen.py +++ b/examples/linkproppred/tkgl-polecat/cen.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNCEN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_cen, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset diff --git a/examples/linkproppred/tkgl-polecat/regcn.py b/examples/linkproppred/tkgl-polecat/regcn.py index 80e3277..6014d68 100644 --- a/examples/linkproppred/tkgl-polecat/regcn.py +++ b/examples/linkproppred/tkgl-polecat/regcn.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNREGCN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_regcn, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset import json diff --git a/examples/linkproppred/tkgl-smallpedia/cen.py b/examples/linkproppred/tkgl-smallpedia/cen.py index 0fa5829..03f3b0f 100644 --- a/examples/linkproppred/tkgl-smallpedia/cen.py +++ b/examples/linkproppred/tkgl-smallpedia/cen.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNCEN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_cen, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset diff --git a/examples/linkproppred/tkgl-smallpedia/regcn.py b/examples/linkproppred/tkgl-smallpedia/regcn.py index 0513f6d..8236319 100644 --- a/examples/linkproppred/tkgl-smallpedia/regcn.py +++ b/examples/linkproppred/tkgl-smallpedia/regcn.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNREGCN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_regcn, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset import json diff --git a/examples/linkproppred/tkgl-wikidata/regcn.py b/examples/linkproppred/tkgl-wikidata/regcn.py index 19aea92..d4282d4 100644 --- a/examples/linkproppred/tkgl-wikidata/regcn.py +++ b/examples/linkproppred/tkgl-wikidata/regcn.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNREGCN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_regcn, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset import json diff --git a/examples/linkproppred/tkgl-yago/cen.py b/examples/linkproppred/tkgl-yago/cen.py index 4fae465..46caeb5 100644 --- a/examples/linkproppred/tkgl-yago/cen.py +++ b/examples/linkproppred/tkgl-yago/cen.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNCEN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_cen, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_cen, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset diff --git a/examples/linkproppred/tkgl-yago/regcn.py b/examples/linkproppred/tkgl-yago/regcn.py index ff502b1..29206ff 100644 --- a/examples/linkproppred/tkgl-yago/regcn.py +++ b/examples/linkproppred/tkgl-yago/regcn.py @@ -20,7 +20,8 @@ sys.path.append(tgb_modules_path) from tgb_modules.rrgcn import RecurrentRGCNREGCN from tgb.utils.utils import set_random_seed, split_by_time, save_results -from tgb_modules.tkg_utils import get_args_regcn, build_sub_graph, reformat_ts +from tgb_modules.tkg_utils import get_args_regcn, reformat_ts +from tgb_modules.tkg_utils_dgl import build_sub_graph from tgb.linkproppred.evaluate import Evaluator from tgb.linkproppred.dataset import LinkPropPredDataset import json diff --git a/tgb_modules/tkg_utils_dgl.py b/tgb_modules/tkg_utils_dgl.py new file mode 100644 index 0000000..bcc64b8 --- /dev/null +++ b/tgb_modules/tkg_utils_dgl.py @@ -0,0 +1,48 @@ + +import dgl +import torch +import numpy as np + + +def build_sub_graph(num_nodes, num_rels, triples, use_cuda, gpu, mode='dyn'): + """ + https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py + :param node_id: node id in the large graph + :param num_rels: number of relation + :param src: relabeled src id + :param rel: original rel id + :param dst: relabeled dst id + :param use_cuda: + :return: + """ + def comp_deg_norm(g): + in_deg = g.in_degrees(range(g.number_of_nodes())).float() + in_deg[torch.nonzero(in_deg == 0).view(-1)] = 1 + norm = 1.0 / in_deg + return norm + + src, rel, dst = triples.transpose() + if mode =='static': + src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) + rel = np.concatenate((rel, rel + num_rels)) + g = dgl.DGLGraph() + g.add_nodes(num_nodes) + #g.ndata['original_id'] = np.unique(np.concatenate((np.unique(triples[:,0]), np.unique(triples[:,2])))) + g.add_edges(src, dst) + norm = comp_deg_norm(g) + #node_id =torch.arange(0, g.num_nodes(), dtype=torch.long).view(-1, 1) #updated to deal with the fact that ot only the first k nodes of our graph have static infos + node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1) + g.ndata.update({'id': node_id, 'norm': norm.view(-1, 1)}) + g.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']}) + g.edata['type'] = torch.LongTensor(rel) + + + uniq_r, r_len, r_to_e = r2e(triples, num_rels) + g.uniq_r = uniq_r + g.r_to_e = r_to_e + g.r_len = r_len + + if use_cuda: + g = g.to(gpu) + g.r_to_e = torch.from_numpy(np.array(r_to_e)) + return g