Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Placn Edits #5

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Fixes for labels
  • Loading branch information
Jae committed Aug 4, 2021
commit ccd3aec0765bae4440a63e7b473c79a6bbe6a1e0
6 changes: 3 additions & 3 deletions subgraph_extraction/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def get_kge_embeddings(dataset, kge_model):
class SubgraphDataset(Dataset):
"""Extracted, labeled, subgraph dataset -- DGL Only"""

def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None, add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='', kge_model='', file_name=''):
def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None, add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='', kge_model='', file_name='', placn_size=20):

self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
self.db_pos = self.main_env.open_db(db_name_pos.encode())
self.db_neg = self.main_env.open_db(db_name_neg.encode())
self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model) if use_kge_embeddings else (None, None)
self.num_neg_samples_per_link = num_neg_samples_per_link
self.file_name = file_name

self.placn_size=placn_size
ssp_graph, __, __, __, id2entity, id2relation = process_files(raw_data_paths, included_relations)
self.num_rels = len(ssp_graph)

Expand Down Expand Up @@ -158,7 +158,7 @@ def _prepare_subgraphs(self, nodes, r_label, n_labels):
def _prepare_features_placn(self, subgraph, n_labels, n_feats=None):
# One-hot encode the node label feature and concatenate it to n_feats
n_nodes = subgraph.number_of_nodes()
label_feats = np.zeros((n_nodes, n_nodes))
label_feats = np.zeros((n_nodes, self.placn_size))
label_feats[np.array(np.arange(n_nodes)), n_labels] = 1
n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats
subgraph.ndata['feat'] = torch.FloatTensor(n_feats)
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(params):
add_traspose_rels=params.add_traspose_rels,
num_neg_samples_per_link=params.num_neg_samples_per_link,
use_kge_embeddings=params.use_kge_embeddings, dataset=params.dataset,
kge_model=params.kge_model, file_name=params.train_file)
kge_model=params.kge_model, file_name=params.train_file, placn_size=params.placn_subgraph_size)
valid = SubgraphDataset(params.db_path, 'valid_pos', 'valid_neg', params.file_paths,
add_traspose_rels=params.add_traspose_rels,
num_neg_samples_per_link=params.num_neg_samples_per_link,
Expand Down
2 changes: 1 addition & 1 deletion utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def ssp_multigraph_to_dgl(graph, n_feats=None):

def collate_dgl(samples):
# The input `samples` is a list of pairs
graphs_pos, g_labels_pos, r_labels_pos, graphs_negs, g_labels_negs, r_labels_negs = map(list, zip(*samples))
graphs_pos, g_labels_pos, r_labels_pos, graphs_negs, g_labels_negs, r_labels_negs = map(list, zip(*samples))
batched_graph_pos = dgl.batch(graphs_pos)

graphs_neg = [item for sublist in graphs_negs for item in sublist]
Expand Down