Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinkNeighborLoader: Support edge_label_time #5137

Merged
merged 49 commits into from
Aug 10, 2022
Merged
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
6b71e0c
first commit for node time attr
wsad1 Aug 4, 2022
0b2ef57
reset time attribute
wsad1 Aug 4, 2022
a7b008e
Merge branch 'master' into link_time_attr
wsad1 Aug 4, 2022
0d0d990
fix
wsad1 Aug 4, 2022
068e7cd
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 8, 2022
550ee5b
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 8, 2022
340356f
add comments
wsad1 Aug 8, 2022
a212f98
update comments
wsad1 Aug 8, 2022
35a0eda
Merge branch 'master' into link_time_attr
wsad1 Aug 8, 2022
52ee2d8
add tests for edge_time
wsad1 Aug 8, 2022
040314f
Merge branch 'link_time_attr' of github.com:wsad1/pytorch_geometric i…
wsad1 Aug 8, 2022
80d7c46
added tests for edge time
wsad1 Aug 8, 2022
3a461ac
added support for edge time with negative sampling
wsad1 Aug 8, 2022
1046e3a
rename function _create_edge_label
wsad1 Aug 8, 2022
a145280
fix typo
wsad1 Aug 8, 2022
89a9686
Merge branch 'master' into link_time_attr
wsad1 Aug 8, 2022
456f13d
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 9, 2022
9b8f0da
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 9, 2022
2bc97e0
Merge branch 'link_time_attr' of github.com:wsad1/pytorch_geometric i…
wsad1 Aug 9, 2022
6b84910
renamed variable to edge_label_time + time_attr only works if edge_ti…
wsad1 Aug 9, 2022
6b40bcf
update test for temporal sampling
wsad1 Aug 9, 2022
ba0292c
Merge branch 'master' into link_time_attr
wsad1 Aug 9, 2022
245f9e6
update TODO
wsad1 Aug 9, 2022
8faad58
Merge branch 'link_time_attr' of github.com:wsad1/pytorch_geometric i…
wsad1 Aug 9, 2022
626fcfd
use right num_nodes argument in LinkSampler
wsad1 Aug 9, 2022
53b44c5
rename to edge_label_time
wsad1 Aug 9, 2022
fab0808
node_time is cloned in every call
wsad1 Aug 10, 2022
6f9e32d
hetero_neighbor_sampler now accepts optional node_time_dict
wsad1 Aug 10, 2022
853bf2f
test with multiple workers
wsad1 Aug 10, 2022
1f13c3a
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
64824b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
9af6006
remove edge_label_time from LinkSampler init
wsad1 Aug 10, 2022
3b7efa1
merge upstream
wsad1 Aug 10, 2022
2dd001d
update changelog
wsad1 Aug 10, 2022
08e0c3d
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
8ffd683
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
947732d
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
83c6a5a
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
409a282
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
4c8759a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
14f187b
Update torch_geometric/loader/link_neighbor_loader.py
wsad1 Aug 10, 2022
d703385
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
f1a4320
rename variables+ add edge_label_time as object variable
wsad1 Aug 10, 2022
37304f2
Merge branch 'link_time_attr' of github.com:wsad1/pytorch_geometric i…
wsad1 Aug 10, 2022
23b3377
simplify code with scatter_min
wsad1 Aug 10, 2022
a64fd12
test if nodes are sampled before edge time and not node time
wsad1 Aug 10, 2022
165efc3
update to _modify_node_time
wsad1 Aug 10, 2022
de4da3a
Merge branch 'master' into link_time_attr
wsad1 Aug 10, 2022
f8d79c9
typos
rusty1s Aug 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update comments
  • Loading branch information
wsad1 committed Aug 8, 2022
commit a212f98e24c8c98620ae58aa5509a94fe7ef2cb9
38 changes: 22 additions & 16 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -31,7 +32,14 @@ def __init__(
super().__init__(data, *args, **kwargs)
self.neg_sampling_ratio = neg_sampling_ratio
self.edge_time = edge_time

# TODO if self.edge_time is not None and
# input_type doesn't have time attribute
# set it to float(inf).

if self.edge_time is not None:
# While sampling node time is replaced by edge time.
# Create a copy to reset node times.
self.copy_time_dict = {}
self.copy_time_dict[self.input_type[0]] = self.node_time_dict[
self.input_type[0]].clone()
Expand All @@ -42,10 +50,6 @@ def __init__(
edge_attrs = graph_store.get_all_edge_attrs()
edge_types = [attr.edge_type for attr in edge_attrs]

# TODO(jinu) if self.edge_time is not None and
# input_type doesn't have time attribute
# set it to 0.

# Edge label index is part of the graph:
if self.input_type in edge_types:
self.num_src_nodes, self.num_dst_nodes = edge_attrs[
Expand Down Expand Up @@ -94,19 +98,23 @@ def _create_label(self, edge_label_index, edge_label):
return edge_label_index, edge_label

def _modify_node_time(self, query_dict, edge_time):
"""For edges in a batch replace `src` and `dst`
node times by the max across all edge times."""
def update_time(input_type):
index = query_dict[input_type]
new_node_time, _ = scatter_max(edge_time, index,
dim_size=self.num_src_nodes)
index_unique = index.unique()
self.node_time_dict[input_type][index_unique] = new_node_time[
index_unique]
self.node_time_dict[input_type][index_unique] = max(
new_node_time[index_unique],
self.node_time_dict[input_type][index_unique])

update_time(self.input_type[0])
# TODO(jinu) input_type[0] = input_type[1]
update_time(self.input_type[1])

def _reset_node_time(self, query_dict):
"""Reset `node_time_dict` to its original
value saved in `copy_time_dict`."""
def reset_time(input_type):
index_unique = query_dict[input_type].unique()
self.node_time_dict[input_type][
Expand Down Expand Up @@ -218,13 +226,10 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
:obj:`neg_sampling_ratio` is currently implemented in an approximate
way, *i.e.* negative edges may contain false negatives.

if :obj:`edge_time` is not :obj:`None` then :obj:`time_attr` needs
to be specified.

if :obj:`edge_time` is :obj:`None`
:obj:`time_attr` is currently implemented such that for an edge
`(src_node, dst_node)`, the neighbors of `src_node` can have a later
timestamp than `dst_node` or vice-versa.
if :obj:`edge_time` is :obj:`None`. :obj:`time_attr` is currently
implemented such that for an edge `(src_node, dst_node)`,
the neighbors of `src_node` can have a later timestamp than
`dst_node` or vice-versa.

Args:
data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
Expand Down Expand Up @@ -334,8 +339,9 @@ def __init__(
device=edge_label_index.device)

if edge_time is not None and time_attr is None:
raise ValueError("`time_attr` has to be specified if"
"`edge_time` is set")
edge_time = None
warnings.warn("`edge_time` is specified by `time_attr` is None."
" No temporal sampling will be done.")

# Save for PyTorch Lightning < 1.6:
self.edge_label = edge_label
Expand Down