
Commit

Merge pull request #52 from JuliaGast/andy_new
updates for merging into main TGB
JuliaGast authored Jun 24, 2024
2 parents bd91148 + 20bb069 commit 5f6a899
Showing 12 changed files with 293 additions and 435 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
!requirements*.txt
get_croissant.py
#dataset
stats_figures/
figs/
4 changes: 4 additions & 0 deletions docs/api/tgb.linkproppred.md
@@ -5,3 +5,7 @@
::: tgb.linkproppred.evaluate
::: tgb.linkproppred.negative_sampler
::: tgb.linkproppred.negative_generator
::: tgb.linkproppred.tkg_negative_generator
::: tgb.linkproppred.tkg_negative_sampler
::: tgb.linkproppred.thg_negative_generator
::: tgb.linkproppred.thg_negative_sampler
1 change: 1 addition & 0 deletions examples/linkproppred/thgl-github/recurrencybaseline.py
@@ -9,6 +9,7 @@
organization={International Joint Conferences on Artificial Intelligence Organization}
}
python recurrencybaseline.py --seed 1 --num_processes 1 -tr False
"""

## imports
File renamed without changes.
550 changes: 275 additions & 275 deletions stats_figures/create_relation_figures.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tgb/linkproppred/evaluate.py
@@ -108,8 +108,8 @@ def _eval_hits_and_mrr(self, y_pred_pos, y_pred_neg, type_info, k_value):

else:
y_pred_pos = y_pred_pos.reshape(-1, 1)
optimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)
pessimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)
optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)
pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)
ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
hitsK_list = (ranking_list <= k_value).astype(np.float32)
mrr_list = 1./ranking_list.astype(np.float32)
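For context, the change above swaps the strict and non-strict comparisons so that optimistic_rank counts only negatives that score strictly above the positive, while pessimistic_rank also counts ties. A minimal, self-contained sketch of the resulting rank computation with toy scores (illustrative values only, not taken from the repository):

import numpy as np

y_pred_pos = np.array([0.9, 0.4]).reshape(-1, 1)              # one positive score per query
y_pred_neg = np.array([[0.8, 0.9, 0.1],                       # negatives for query 1 (one tie at 0.9)
                       [0.5, 0.6, 0.7]])                      # negatives for query 2

optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)       # negatives strictly better than the positive
pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)     # ties also count against the positive
ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1 # average the two, 1-indexed rank
hitsK_list = (ranking_list <= 3).astype(np.float32)           # Hits@3 per query
mrr_list = 1. / ranking_list.astype(np.float32)               # reciprocal ranks

print(ranking_list, hitsK_list, mrr_list)   # [1.5 4. ] [1. 0.] [~0.667 0.25]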
12 changes: 5 additions & 7 deletions tgb/linkproppred/thg_negative_generator.py
@@ -29,7 +29,7 @@ def __init__(
edge_data: TemporalData = None,
) -> None:
r"""
Negative Edge Sampler class
Negative Edge Generator class for Temporal Heterogeneous Graphs
this is a class for generating negative samples for a specific datasets
the set of the positive samples are provided, the negative samples are generated with specific strategies
and are saved for consistent evaluation across different methods
@@ -39,11 +39,10 @@
first_node_id: the first node id
last_node_id: the last node id
node_type: the node type of each node
num_neg_e: number of negative edges being generated per each positive edge
strategy: specifies which strategy should be used for generating the negatives
rnd_seed: random seed for reproducibility
edge_data: the positive edges to generate the negatives for, assuming sorted temporally
strategy: the strategy to generate negative samples
num_neg_e: number of negative samples to generate
rnd_seed: random seed
edge_data: the edge data object containing the positive edges
Returns:
None
"""
@@ -72,7 +71,6 @@ def get_destinations_based_on_node_type(self,
node_type: np.ndarray) -> dict:
r"""
get the destination node id arrays based on the node type
Parameters:
first_node_id: the first node id
last_node_id: the last node id
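The get_destinations_based_on_node_type docstring above describes grouping destination candidates by node type. A hedged illustration of that idea with made-up ids and types (assumed variable names, not the repository implementation):

import numpy as np

first_node_id, last_node_id = 0, 9
node_type = np.array([0, 0, 1, 1, 1, 2, 2, 0, 1, 2])     # hypothetical type per node id

node_type_dict = {}
for node_id in range(first_node_id, last_node_id + 1):
    node_type_dict.setdefault(int(node_type[node_id]), []).append(node_id)

# one array of same-typed nodes per type; negative destinations for an edge are
# drawn from the array matching the positive destination's node type
node_type_dict = {t: np.array(ids) for t, ids in node_type_dict.items()}
print(node_type_dict)   # {0: array([0, 1, 7]), 1: array([2, 3, 4, 8]), 2: array([5, 6, 9])}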
12 changes: 1 addition & 11 deletions tgb/linkproppred/thg_negative_sampler.py
@@ -26,7 +26,7 @@
r"""
Negative Edge Sampler
Loads and query the negative batches based on the positive batches provided.
constructor for the negative edge sampler class
constructor for the negative edge sampler class
Parameters:
dataset_name: name of the dataset
@@ -124,16 +124,6 @@ def query_batch(self,
neg_samples.append(
neg_d_arr
)

# conflict_set, d_node_type = conflict_dict[(pos_t, pos_s, e_type)]

# all_dst = self.node_type_dict[d_node_type]
# # filtered_all_dst = np.delete(all_dst, conflict_set, axis=0)
# filtered_all_dst = np.setdiff1d(all_dst, conflict_set)
# neg_d_arr = filtered_all_dst
# neg_samples.append(
# neg_d_arr
# )

#? can't convert to numpy array due to different lengths of negative samples
return neg_samples
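The trailing comment explains why query_batch returns a plain Python list: the negative sets can differ in length across positive edges, so they cannot be stacked into one rectangular numpy array. A tiny illustration with hypothetical values:

import numpy as np

neg_samples = [np.array([3, 7, 9]), np.array([1, 2, 4, 5, 8])]  # different lengths per positive edge
# np.stack(neg_samples) raises, and np.array(neg_samples) at best yields a ragged
# object array, so evaluation code iterates over the list entry by entry instead.
for neg_d_arr in neg_samples:
    print(len(neg_d_arr))   # 3, then 5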
132 changes: 3 additions & 129 deletions tgb/linkproppred/tkg_negative_generator.py
@@ -28,13 +28,8 @@
edge_data: TemporalData = None,
) -> None:
r"""
Negative Edge Sampler class
this is a class for generating negative samples for a specific datasets
the set of the positive samples are provided, the negative samples are generated with specific strategies
and are saved for consistent evaluation across different methods
negative edges are sampled with 'oen_vs_many' strategy.
it is assumed that the destination nodes are indexed sequentially with 'first_dst_id'
and 'last_dst_id' being the first and last index, respectively.
Negative Edge Generator class for Temporal Knowledge Graphs
constructor for the negative edge generator class
Parameters:
dataset_name: name of the dataset
@@ -121,13 +116,6 @@ def generate_dst_dict(self, edge_data: TemporalData, dst_name: str) -> dict:
edge_type_size = []
for key in dst_track_dict:
dst = np.array(list(dst_track_dict[key].keys()))
# #* if there are too few dst, sample up to 1000
# if len(dst) < 1000:
# dst_sampled = np.random.choice(np.arange(min_dst_idx, max_dst_idx+1), 1000, replace=False)
# while np.intersect1d(dst, dst_sampled).shape[0] != 0:
# dst_sampled = np.random.choice(np.arange(min_dst_idx, max_dst_idx+1), 1000, replace=False)
# dst_sampled[0:len(dst)] = dst[:]
# dst = dst_sampled
edge_type_size.append(len(dst))
dst_dict[key] = dst
print ('destination candidates generated for all edge types ', len(dst_dict))
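For orientation, a hedged sketch of the dst_dict structure this loop produces, using toy quadruples and assumed variable names rather than the dataset itself: one array of observed destination candidates per edge type.

import numpy as np

dst = np.array([5, 6, 5, 7, 8])          # hypothetical destinations
edge_type = np.array([0, 0, 1, 1, 1])    # hypothetical relation per edge

dst_track_dict = {}                       # {edge_type: {destination: 1}}
for d, r in zip(dst, edge_type):
    dst_track_dict.setdefault(int(r), {})[int(d)] = 1   # dict keys deduplicate destinations

dst_dict = {}
edge_type_size = []
for key in dst_track_dict:
    dst_arr = np.array(list(dst_track_dict[key].keys()))
    edge_type_size.append(len(dst_arr))
    dst_dict[key] = dst_arr
print(dst_dict)   # {0: array([5, 6]), 1: array([5, 7, 8])}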
@@ -401,118 +389,4 @@ def generate_negative_samples_random(self,
evaluation_set[(pos_t, pos_s, edge_type)] = neg_d_arr
save_pkl(evaluation_set, filename)





# def generate_negative_samples_ftr(self,
# data: TemporalData,
# split_mode: str,
# filename: str,
# ) -> None:
# r"""
# now we consider (s, d, t, edge_type) as a unique edge
# Generate negative samples based on the random strategy:
# - for each positive edge, sample a batch of negative edges from all possible edges with the same source node
# - filter actual positive edges at the same timestamp with the same edge type

# Parameters:
# data: an object containing positive edges information
# split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
# filename: name of the file containing the generated negative edges
# """
# print(
# f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
# )
# assert split_mode in [
# "val",
# "test",
# ], "Invalid split-mode! It should be `val` or `test`!"

# if os.path.exists(filename):
# print(
# f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
# )
# else:
# print(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
# # retrieve the information from the batch
# pos_src, pos_dst, pos_timestamp, edge_type = (
# data.src.cpu().numpy(),
# data.dst.cpu().numpy(),
# data.t.cpu().numpy(),
# data.edge_type.cpu().numpy(),
# )

# # all possible destinations
# all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)
# evaluation_set = {}
# # generate a list of negative destinations for each positive edge
# pos_edge_tqdm = tqdm(
# zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
# )

# edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }
# #! iterate once to put all edges into a dictionary for reference
# for (
# pos_s,
# pos_d,
# pos_t,
# edge_type,
# ) in pos_edge_tqdm:
# if (pos_t, pos_s, edge_type) not in edge_t_dict:
# edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}
# else:
# edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1

# conflict_dict = {}
# for key in edge_t_dict:
# conflict_dict[key] = np.array(list(edge_t_dict[key].keys()))

# print ("conflict sets for ns samples for ", len(conflict_dict), " positive edges are generated")

# # save the generated evaluation set to disk
# save_pkl(conflict_dict, filename)

# # pos_src, pos_dst, pos_timestamp, edge_type = (
# # data.src.cpu().numpy(),
# # data.dst.cpu().numpy(),
# # data.t.cpu().numpy(),
# # data.edge_type.cpu().numpy(),
# # )


# # # generate a list of negative destinations for each positive edge
# # pos_edge_tqdm = tqdm(
# # zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
# # )


# # for (
# # pos_s,
# # pos_d,
# # pos_t,
# # edge_type,
# # ) in pos_edge_tqdm:

# # #! generate all negatives unless restricted
# # conflict_set = list(edge_t_dict[(pos_t, pos_s, edge_type)].keys())

# # # filter out positive destination
# # conflict_set = np.array(conflict_set)
# # filtered_all_dst = np.setdiff1d(all_dst, conflict_set)

# # '''
# # when num_neg_e is larger than all possible destinations simple return all possible destinations
# # '''
# # if (self.num_neg_e < 0):
# # neg_d_arr = filtered_all_dst
# # elif (self.num_neg_e > len(filtered_all_dst)):
# # neg_d_arr = filtered_all_dst
# # else:
# # neg_d_arr = np.random.choice(
# # filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives

# # evaluation_set[(pos_s, pos_d, pos_t, edge_type)] = neg_d_arr

# # # save the generated evaluation set to disk
# # save_pkl(evaluation_set, filename)

12 changes: 1 addition & 11 deletions tgb/linkproppred/tkg_negative_sampler.py
@@ -44,17 +44,7 @@ def __init__(
self.last_dst_id = last_dst_id
self.strategy = strategy
self.dst_dict = None
# if self.strategy in ["dst-time-filtered"]:
# dst_dict_name = (
# partial_path
# + "_"
# + "dst_dict"
# + ".pkl"
# )
# if not os.path.exists(dst_dict_name):
# raise FileNotFoundError(f"File not found at {dst_dict_name}, dst_time_filtered strategy requires the dst_dict file")
# self.dst_dict = load_pkl(dst_dict_name)


def load_eval_set(
self,
fname: str,
File renamed without changes.
File renamed without changes.

