Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sampler): consolidate link neighbor sampling interface, part 2 #5365

Merged
merged 41 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
0b976ee
init
mananshah99 Aug 29, 2022
015743e
Merge branch 'master' of github.com:pyg-team/pytorch_geometric into r…
mananshah99 Aug 29, 2022
b111754
update
mananshah99 Aug 29, 2022
9a36b3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2022
f23586b
update
mananshah99 Aug 29, 2022
5c35cd6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2022
6545ced
update
mananshah99 Aug 29, 2022
aebf0d5
more cleanup
mananshah99 Aug 29, 2022
6741272
update
mananshah99 Aug 29, 2022
aacc361
fix
mananshah99 Aug 29, 2022
4b8d88f
init
mananshah99 Aug 30, 2022
1b0bfc8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2022
d7ec97c
merge
mananshah99 Aug 30, 2022
5a88c3d
Merge branch 'remote_backend_2' of github.com:pyg-team/pytorch_geomet…
mananshah99 Aug 30, 2022
71b76e3
rm
mananshah99 Aug 30, 2022
dfa34dc
update
mananshah99 Sep 6, 2022
0058f7f
minor
mananshah99 Sep 6, 2022
fc4794f
merge
mananshah99 Sep 6, 2022
d7d1e5e
udpate
mananshah99 Sep 6, 2022
46aacc0
update
mananshah99 Sep 6, 2022
fe98488
init
mananshah99 Sep 7, 2022
a0b82b1
merge with part 1
mananshah99 Sep 7, 2022
fcd086f
merge
mananshah99 Sep 7, 2022
91f18a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
bdf1682
merge leftovers
mananshah99 Sep 7, 2022
d77e314
Merge branch 'remote_backend_3' of github.com:pyg-team/pytorch_geomet…
mananshah99 Sep 7, 2022
28289c8
update
mananshah99 Sep 7, 2022
d2caca9
flake8
mananshah99 Sep 7, 2022
1f77437
update
mananshah99 Sep 7, 2022
6d8dc23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
168ee7b
update
mananshah99 Sep 7, 2022
b695a5d
update
mananshah99 Sep 7, 2022
75ae7a4
update
mananshah99 Sep 7, 2022
39c003c
Merge branch 'master' into remote_backend_3
rusty1s Sep 8, 2022
6dc4282
merge
mananshah99 Sep 8, 2022
d64a149
udpate
mananshah99 Sep 8, 2022
9e67de5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2022
22c1913
update
mananshah99 Sep 8, 2022
5c3bac3
merge
mananshah99 Sep 8, 2022
88e2764
cleanup
mananshah99 Sep 8, 2022
b5d1b6e
final
mananshah99 Sep 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 30, 2022
commit 1b0bfc83e0c12e8c4d2c6c32eb4f8a17bcddbf16
6 changes: 4 additions & 2 deletions torch_geometric/sampler/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, Tuple, Union, List
from typing import Dict, List, NamedTuple, Tuple, Union

import torch
from torch_geometric.typing import NodeType, EdgeType

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.typing import EdgeType, NodeType

# An input to a sampler is either a list or tensor of node indices:
SamplerInput = Union[List[int], torch.Tensor]
Expand Down
8 changes: 6 additions & 2 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
to_csc,
to_hetero_csc,
)
from torch_geometric.sampler.base import BaseSampler, SamplerInput, SamplerOutput
from torch_geometric.sampler.base import (
BaseSampler,
SamplerInput,
SamplerOutput,
)
from torch_geometric.typing import NumNeighbors


Expand Down Expand Up @@ -205,4 +209,4 @@ def __call__(self, index: SamplerInput) -> SamplerOutput:
elif self.data_cls == 'custom' or issubclass(self.data_cls,
HeteroData):
return self._hetero_sparse_neighbor_sample(
{self.input_type: index}) + (index.numel(), )
{self.input_type: index}) + (index.numel(), )