Skip to content

Commit 6e88368

Browse files
EdisonLeeeeepre-commit-ci[bot]rusty1s
authored
Let ImbalancedSampler accept torch.Tensor as input (pyg-team#5138)
* Fix ImbalancedSampler to support torch.Tensor as input * Update test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * changelog * update Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 7692969 commit 6e88368

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138))
89
- Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149))
910
- `NeighborSampler` supports graphs without edges ([#5072](https://github.com/pyg-team/pytorch_geometric/pull/5072))
1011
- Added the `MeanSubtractionNorm` layer ([#5068](https://github.com/pyg-team/pytorch_geometric/pull/5068))

test/loader/test_imbalanced_sampler.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import List
22

33
import torch
4-
from torch import Tensor
54

65
from torch_geometric.data import Data
76
from torch_geometric.loader import (
@@ -22,16 +21,21 @@ def test_dataloader_with_imbalanced_sampler():
2221
sampler = ImbalancedSampler(data_list)
2322
loader = DataLoader(data_list, batch_size=10, sampler=sampler)
2423

25-
ys: List[Tensor] = []
26-
for batch in loader:
27-
ys.append(batch.y)
24+
y = torch.cat([batch.y for batch in loader])
2825

29-
histogram = torch.cat(ys).bincount()
26+
histogram = y.bincount()
3027
prob = histogram / histogram.sum()
3128

3229
assert histogram.sum() == len(data_list)
3330
assert prob.min() > 0.4 and prob.max() < 0.6
3431

32+
# Test with label tensor as input:
33+
torch.manual_seed(12345)
34+
sampler = ImbalancedSampler(torch.tensor([data.y for data in data_list]))
35+
loader = DataLoader(data_list, batch_size=10, sampler=sampler)
36+
37+
assert torch.allclose(y, torch.cat([batch.y for batch in loader]))
38+
3539

3640
def test_neighbor_loader_with_imbalanced_sampler():
3741
zeros = torch.zeros(10, dtype=torch.long)
@@ -46,12 +50,18 @@ def test_neighbor_loader_with_imbalanced_sampler():
4650
loader = NeighborLoader(data, batch_size=10, sampler=sampler,
4751
num_neighbors=[-1])
4852

49-
ys: List[Tensor] = []
50-
for batch in loader:
51-
ys.append(batch.y)
53+
y = torch.cat([batch.y for batch in loader])
5254

53-
histogram = torch.cat(ys).bincount()
55+
histogram = y.bincount()
5456
prob = histogram / histogram.sum()
5557

5658
assert histogram.sum() == data.num_nodes
5759
assert prob.min() > 0.4 and prob.max() < 0.6
60+
61+
# Test with label tensor as input:
62+
torch.manual_seed(12345)
63+
sampler = ImbalancedSampler(data.y)
64+
loader = NeighborLoader(data, batch_size=10, sampler=sampler,
65+
num_neighbors=[-1])
66+
67+
assert torch.allclose(y, torch.cat([batch.y for batch in loader]))

torch_geometric/loader/imbalanced_sampler.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,23 @@ class distribution.
3333
batch_size=64, num_neighbors=[-1, -1],
3434
sampler=sampler, ...)
3535
36+
You can also pass in the class labels directly as a :class:`torch.Tensor`:
37+
38+
.. code-block:: python
39+
40+
from torch_geometric.loader import NeighborLoader, ImbalancedSampler
41+
42+
sampler = ImbalancedSampler(data.y)
43+
loader = NeighborLoader(data, input_nodes=data.train_mask,
44+
batch_size=64, num_neighbors=[-1, -1],
45+
sampler=sampler, ...)
46+
3647
Args:
37-
dataset (Dataset or Data): The dataset from which to sample the data,
38-
either given as a :class:`~torch_geometric.data.Dataset` or
39-
:class:`~torch_geometric.data.Data` object.
48+
dataset (Dataset or Data or Tensor): The dataset or class distribution
49+
from which to sample the data, given either as a
50+
:class:`~torch_geometric.data.Dataset`,
51+
:class:`~torch_geometric.data.Data`, or :class:`torch.Tensor`
52+
object.
4053
input_nodes (Tensor, optional): The indices of nodes that are used by
4154
the corresponding loader, *e.g.*, by
4255
:class:`~torch_geometric.loader.NeighborLoader`.
@@ -50,7 +63,7 @@ class distribution.
5063
"""
5164
def __init__(
5265
self,
53-
dataset: Union[Data, Dataset, List[Data]],
66+
dataset: Union[Dataset, Data, List[Data], Tensor],
5467
input_nodes: Optional[Tensor] = None,
5568
num_samples: Optional[int] = None,
5669
):
@@ -60,6 +73,10 @@ def __init__(
6073
assert dataset.num_nodes == y.numel()
6174
y = y[input_nodes] if input_nodes is not None else y
6275

76+
elif isinstance(dataset, Tensor):
77+
y = dataset.view(-1)
78+
y = y[input_nodes] if input_nodes is not None else y
79+
6380
elif isinstance(dataset, InMemoryDataset):
6481
y = dataset.data.y.view(-1)
6582
assert len(dataset) == y.numel()

0 commit comments

Comments
 (0)