
Commit

cleanup
vardaan123 committed Oct 22, 2023
1 parent 570d39e commit caf97de
Showing 4 changed files with 3 additions and 413 deletions.
1 change: 0 additions & 1 deletion dataset.py
@@ -64,7 +64,6 @@ def __init__(self, dataset_path, split, neigh_size, sampling_type, graph_connect

# val data can be based on train graph or eval graph for MLM; while for link prediction, it is always train graph
self.adj_list, _ = get_adj_and_degrees(self.num_nodes, self.train_data)
# self.adj_list = get_adj(self.num_nodes, self.train_data)

# Code credit: https://github.com/kkteru/grail
# Construct the list of adjacency matrix each corresponding to each relation. Note that this is constructed only from the train data.
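For context, get_adj_and_degrees builds a per-node adjacency structure from the training triples; the docstring in utils/dgl_utils.py below notes that each adj_list entry holds (edge_id, node_id) pairs. A minimal sketch of what such a helper could look like, assuming train_data is an iterable of (head, relation, tail) triples; this is an illustration, not the repository's actual implementation:

import numpy as np

def get_adj_and_degrees(num_nodes, triples):
    # Sketch only: for every node, collect the (edge_id, neighbor_id) pairs
    # touching it, then derive node degrees from those lists.
    adj_list = [[] for _ in range(num_nodes)]
    for edge_id, (head, rel, tail) in enumerate(triples):
        adj_list[head].append([edge_id, tail])
        adj_list[tail].append([edge_id, head])
    degrees = np.array([len(neighbors) for neighbors in adj_list])
    adj_list = [np.array(neighbors) for neighbors in adj_list]
    return adj_list, degrees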
1 change: 0 additions & 1 deletion dataset_eval.py
@@ -90,7 +90,6 @@ def __init__(self, dataset_path, split, mode, sample_size, neigh_size, sampling_

print(self.db_path)

# @profile
def __getitem__(self, idx):
edge = self.triples[idx]
# print('edge = {}'.format(edge))
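The __getitem__ above simply looks up the triple stored at the requested index. A stripped-down sketch of that access pattern, assuming a PyTorch-style Dataset and hypothetical class and field names (the real class also carries db_path and the sampling parameters from its constructor):

from torch.utils.data import Dataset

class EvalTriplesDataset(Dataset):
    # Hypothetical, simplified stand-in for the evaluation dataset class.
    def __init__(self, triples):
        self.triples = triples  # list of (head, relation, tail) tuples

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        edge = self.triples[idx]  # the triple evaluated at this index
        return edge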
17 changes: 3 additions & 14 deletions utils/dgl_utils.py
@@ -11,8 +11,6 @@ def _bfs_relational(adj, adj_list, roots, max_nodes_per_hop=None, n_neigh_per_no
Modified from dgl.contrib.data.knowledge_graph to accommodate node sampling
adj_list: contains (edge_id, node_id) pairs
"""
# print('max_nodes_per_hop = {}'.format(max_nodes_per_hop))

visited = set()
current_lvl = set(roots)

@@ -23,37 +21,28 @@
for v in current_lvl:
visited.add(v)

# get all neighbors in next hop
# next_lvl = _get_neighbors(adj, adj_list, current_lvl) # next_level should contain (node id, edge_id) pairs

# TODO: convert it to use adj_list instead of adj matrix
next_lvl = set()
for node in current_lvl:
if n_neigh_per_node:
n_edge_samples = min(n_neigh_per_node, len(adj_list[node]))
else:
n_edge_samples = len(adj_list[node])

# print('n_edge_samples = {}'.format(n_edge_samples))

# next_lvl.update([x for x in adj_list[node] if x[1] not in visited])
# remove all nodes already covered
node_samples = [x for x in random.sample(adj_list[node].tolist(), n_edge_samples) if x[1] not in visited]
node_samples = [tuple(x) for x in node_samples]

# node_samples = [x for x in adj_list[node] if x[1] not in visited]

next_lvl.update(node_samples)

# remove all nodes already covered
# next_lvl -= visited # set difference # implement set difference based on node ids

# TODO: support max_nodes_per_hop
# support max_nodes_per_hop
if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl):
next_lvl = set(random.sample(next_lvl, max_nodes_per_hop))

yield next_lvl

# TODO: select only node ids from next_lvl
# select only node ids from next_lvl
current_lvl = set.union(set([x[1] for x in next_lvl]))


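Taken together, the loop above performs a relational breadth-first expansion: each frontier node contributes at most n_neigh_per_node randomly sampled (edge_id, node_id) pairs, pairs leading to already-visited nodes are dropped, and each hop can optionally be capped at max_nodes_per_hop. A self-contained sketch of that sampling scheme, simplified to plain Python containers and a hypothetical function name:

import random

def sample_bfs_levels(adj_list, roots, max_nodes_per_hop=None, n_neigh_per_node=None):
    # Yield one set of (edge_id, node_id) pairs per hop, starting from roots.
    # adj_list[node] is assumed to be a list of (edge_id, neighbor_id) pairs.
    visited = set()
    current_lvl = set(roots)
    while current_lvl:
        visited |= current_lvl

        next_lvl = set()
        for node in current_lvl:
            neighbors = adj_list[node]
            # cap the number of edges sampled per frontier node
            if n_neigh_per_node:
                k = min(n_neigh_per_node, len(neighbors))
            else:
                k = len(neighbors)
            samples = random.sample(list(neighbors), k)
            # keep only edges that lead to nodes not yet visited
            next_lvl.update(tuple(x) for x in samples if x[1] not in visited)

        # optionally cap the total frontier size for this hop
        if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl):
            next_lvl = set(random.sample(sorted(next_lvl), max_nodes_per_hop))

        yield next_lvl

        # the next frontier is the set of newly reached node ids
        current_lvl = {x[1] for x in next_lvl}

As with the original generator, a caller would typically draw a fixed number of hops from this and union the yielded node ids into a sampled subgraph.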
