Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize loader arguments for evaluation in LightningDataModule #6450

Merged
merged 5 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
rusty1s committed Jan 17, 2023
commit 5fc30e3a91f3f41d3c6a7f8c016c9586c8eb7e61
11 changes: 6 additions & 5 deletions test/data/lightning/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.01)


# @onlyCUDA
# @onlyFullTest
@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
@pytest.mark.parametrize('loader', ['neighbor'])
@pytest.mark.parametrize('strategy_type', [None])
@pytest.mark.parametrize('loader', ['full', 'neighbor'])
@pytest.mark.parametrize('strategy_type', [None, 'ddp_spawn'])
def test_lightning_node_data(get_dataset, strategy_type, loader):
import pytorch_lightning as pl

Expand Down Expand Up @@ -198,11 +198,12 @@ def test_lightning_node_data(get_dataset, strategy_type, loader):
max_epochs=5, log_every_n_steps=1)
datamodule = LightningNodeData(data, loader=loader, batch_size=batch_size,
num_workers=num_workers, **kwargs)
return

old_x = data.x.clone().cpu()
assert str(datamodule) == (f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, '
f'num_neighbors=[5], '
f'pin_memory={loader != "full"}, '
f'persistent_workers={loader != "full"})')
trainer.fit(model, datamodule)
Expand Down
187 changes: 138 additions & 49 deletions torch_geometric/data/lightning/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class LightningDataset(LightningDataModule):
(default: :obj:`None`)
test_dataset (Dataset, optional): The test dataset.
(default: :obj:`None`)
pred_dataset (Dataset, optional): The prediction dataset.
(default: :obj:`None`)
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
num_workers (int): How many subprocesses to use for data loading.
Expand All @@ -125,6 +127,7 @@ def __init__(
train_dataset: Dataset,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
pred_dataset: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: int = 0,
**kwargs,
Expand All @@ -140,6 +143,7 @@ def __init__(
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.pred_dataset = pred_dataset

def dataloader(self, dataset: Dataset, **kwargs) -> DataLoader:
return DataLoader(dataset, **kwargs)
Expand Down Expand Up @@ -170,6 +174,14 @@ def test_dataloader(self) -> DataLoader:

return self.dataloader(self.test_dataset, shuffle=False, **kwargs)

def predict_dataloader(self) -> DataLoader:
""""""
kwargs = copy.copy(self.kwargs)
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

return self.dataloader(self.pred_dataset, shuffle=False, **kwargs)

def __repr__(self) -> str:
kwargs = kwargs_repr(train_dataset=self.train_dataset,
val_dataset=self.val_dataset,
Expand Down Expand Up @@ -251,7 +263,7 @@ class LightningNodeData(LightningDataModule):
:obj:`0` means that the data will be loaded in the main process.
(default: :obj:`0`)
eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments
that override the :class:`~torch_geometric.loader.NeighborLoader`
that override the :class:`torch_geometric.loader.NeighborLoader`
configuration during evaluation. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.loader.NeighborLoader`.
Expand Down Expand Up @@ -363,7 +375,7 @@ def __init__(
# The user wants to override certain values during evaluation, so
# we shallow-copy the sampler and update its attributes.

if hasattr(self.neighbor_sampler):
if hasattr(self, 'neighbor_sampler'):
self.eval_neighbor_sampler = copy.copy(self.neighbor_sampler)

eval_sampler_kwargs, self.eval_loader_kwargs = split_kwargs(
Expand All @@ -376,7 +388,7 @@ def __init__(
self.eval_loader_kwargs = copy.copy(self.loader_kwargs)
self.eval_loader_kwargs.update(eval_loader_kwargs)
else:
if hasattr(self.neighbor_sampler):
if hasattr(self, 'neighbor_sampler'):
self.eval_neighbor_sampler = self.neighbor_sampler

self.eval_loader_kwargs = self.loader_kwargs
Expand Down Expand Up @@ -488,10 +500,6 @@ def test_dataloader(self) -> DataLoader:

def predict_dataloader(self) -> DataLoader:
""""""
kwargs = copy.copy(self.kwargs)
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

return self.dataloader(
self.input_pred_nodes,
self.input_pred_time,
Expand Down Expand Up @@ -543,9 +551,9 @@ class LightningLinkData(LightningDataModule):
input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
The training edges. (default: :obj:`None`)
input_train_labels (torch.Tensor, optional):
The labels of train edges. (default: :obj:`None`)
The labels of training edges. (default: :obj:`None`)
input_train_time (torch.Tensor, optional): The timestamp
of train edges. (default: :obj:`None`)
of training edges. (default: :obj:`None`)
input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
The validation edges. (default: :obj:`None`)
input_val_labels (torch.Tensor, optional):
Expand All @@ -555,9 +563,15 @@ class LightningLinkData(LightningDataModule):
input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
The test edges. (default: :obj:`None`)
input_test_labels (torch.Tensor, optional):
The labels of train edges. (default: :obj:`None`)
The labels of test edges. (default: :obj:`None`)
input_test_time (torch.Tensor, optional): The timestamp
of test edges. (default: :obj:`None`)
input_pred_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
The prediction edges. (default: :obj:`None`)
input_pred_labels (torch.Tensor, optional):
The labels of prediction edges. (default: :obj:`None`)
input_pred_time (torch.Tensor, optional): The timestamp
of prediction edges. (default: :obj:`None`)
loader (str): The scalability technique to use (:obj:`"full"`,
:obj:`"neighbor"`). (default: :obj:`"neighbor"`)
link_sampler (BaseSampler, optional): A custom sampler object to
Expand All @@ -568,6 +582,10 @@ class LightningLinkData(LightningDataModule):
num_workers (int): How many subprocesses to use for data loading.
:obj:`0` means that the data will be loaded in the main process.
(default: :obj:`0`)
eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments
that override the
:class:`torch_geometric.loader.LinkNeighborLoader` configuration
during evaluation. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.loader.LinkNeighborLoader`.
"""
Expand All @@ -583,10 +601,14 @@ def __init__(
input_test_edges: InputEdges = None,
input_test_labels: OptTensor = None,
input_test_time: OptTensor = None,
input_pred_edges: InputEdges = None,
input_pred_labels: OptTensor = None,
input_pred_time: OptTensor = None,
loader: str = "neighbor",
link_sampler: Optional[BaseSampler] = None,
batch_size: int = 1,
num_workers: int = 0,
eval_loader_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
if link_sampler is not None:
Expand All @@ -610,6 +632,16 @@ def __init__(
f"(got '{num_workers}')")
num_workers = 0

if loader == 'full' and kwargs.get('sampler') is not None:
warnings.warn("'sampler' option is not supported for "
"loader='full'")
kwargs.pop('sampler', None)

if loader == 'full' and kwargs.get('batch_sampler') is not None:
    warnings.warn("'batch_sampler' option is not supported for "
                  "loader='full'")
    # Remove the option that was just warned about. Popping 'sampler'
    # here would leave 'batch_sampler' in `kwargs`.
    kwargs.pop('batch_sampler', None)

super().__init__(
has_val=input_val_edges is not None,
has_test=input_test_edges is not None,
Expand All @@ -628,27 +660,56 @@ def __init__(
self.data = data
self.loader = loader

# Determine sampler and loader arguments ##############################

if loader in ['neighbor', 'link_neighbor']:
sampler_args = dict(inspect.signature(NeighborSampler).parameters)
sampler_args.pop('data')
sampler_args.pop('share_memory')
sampler_kwargs = {
key: kwargs.get(key, param.default)
for key, param in sampler_args.items()
}
self.neighbor_sampler = NeighborSampler(
data=data,
share_memory=num_workers > 0,
**sampler_kwargs,
# Define a new `NeighborSampler` that can be re-used across
# different data loaders.
sampler_kwargs, self.loader_kwargs = split_kwargs(
self.kwargs,
NeighborSampler,
)
elif link_sampler is not None:
sampler_kwargs.setdefault('share_memory', num_workers > 0)

# TODO Consider renaming to `self.link_sampler`
self.neighbor_sampler = NeighborSampler(data, **sampler_kwargs)

elif link_sampler is not None:
_, self.loader_kwargs = split_kwargs(
self.kwargs,
link_sampler.__class__,
)
self.neighbor_sampler = link_sampler

if getattr(self, 'neighbor_sampler', None) is not None:
cls = self.neighbor_sampler.__class__
for param in inspect.signature(cls).parameters:
self.kwargs.pop(param, None)
else:
self.loader_kwargs = self.kwargs

# Determine validation sampler and loader arguments ###################

if eval_loader_kwargs is not None:
    # The user wants to override certain values during evaluation, so
    # we shallow-copy the sampler and update its attributes.
    if hasattr(self, 'neighbor_sampler'):
        self.eval_neighbor_sampler = copy.copy(self.neighbor_sampler)

        # Sampler-specific overrides are applied to the copied sampler;
        # whatever remains is treated as loader overrides.
        eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(
            eval_loader_kwargs,
            self.neighbor_sampler.__class__,
        )
        for key, value in eval_sampler_kwargs.items():
            setattr(self.eval_neighbor_sampler, key, value)

    # In both cases the evaluation loader starts from the training loader
    # arguments and only the user-supplied overrides are layered on top.
    self.eval_loader_kwargs = copy.copy(self.loader_kwargs)
    self.eval_loader_kwargs.update(eval_loader_kwargs)
else:
    if hasattr(self, 'neighbor_sampler'):
        self.eval_neighbor_sampler = self.neighbor_sampler

    # Copy instead of aliasing: the pops below must not remove
    # 'sampler'/'batch_sampler' from the training loader arguments.
    self.eval_loader_kwargs = copy.copy(self.loader_kwargs)

# Evaluation loaders are created with `shuffle=False` and never use a
# custom sampler or batch sampler.
self.eval_loader_kwargs.pop('sampler', None)
self.eval_loader_kwargs.pop('batch_sampler', None)

self.input_train_edges = input_train_edges
self.input_train_labels = input_train_labels
Expand All @@ -659,11 +720,15 @@ def __init__(
self.input_test_edges = input_test_edges
self.input_test_labels = input_test_labels
self.input_test_time = input_test_time
self.input_pred_edges = input_pred_edges
self.input_pred_labels = input_pred_labels
self.input_pred_time = input_pred_time

# Can be overridden to set input indices of the `LinkLoader`:
self.input_train_id: OptTensor = None
self.input_val_id: OptTensor = None
self.input_test_id: OptTensor = None
self.input_pred_id: OptTensor = None

def prepare_data(self):
""""""
Expand All @@ -687,26 +752,28 @@ def dataloader(
input_labels: OptTensor = None,
input_time: OptTensor = None,
input_id: OptTensor = None,
link_sampler: Optional[BaseSampler] = None,
**kwargs,
) -> DataLoader:
if self.loader == 'full':
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')

kwargs['shuffle'] = True
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

return torch.utils.data.DataLoader(
[self.data],
collate_fn=lambda xs: xs[0],
**kwargs,
)

else:
if link_sampler is None:
warnings.warn("No 'link_sampler' specified. Falling back to "
"using the default training sampler.")
link_sampler = self.neighbor_sampler

return LinkLoader(
self.data,
link_sampler=self.neighbor_sampler,
link_sampler=link_sampler,
edge_label_index=input_edges,
edge_label=input_labels,
edge_label_time=input_time,
Expand All @@ -716,32 +783,54 @@ def dataloader(

def train_dataloader(self) -> DataLoader:
""""""
shuffle = (self.kwargs.get('sampler', None) is None
and self.kwargs.get('batch_sampler', None) is None)
shuffle = self.kwargs.get('sampler', None) is None
shuffle &= self.kwargs.get('batch_sampler', None) is None

return self.dataloader(self.input_train_edges, self.input_train_labels,
self.input_train_time, self.input_train_id,
shuffle=shuffle, **self.kwargs)
return self.dataloader(
self.input_train_edges,
self.input_train_labels,
self.input_train_time,
self.input_train_id,
link_sampler=getattr(self, 'neighbor_sampler', None),
shuffle=shuffle,
**self.loader_kwargs,
)

def val_dataloader(self) -> DataLoader:
""""""
kwargs = copy.copy(self.kwargs)
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

return self.dataloader(self.input_val_edges, self.input_val_labels,
self.input_val_time, self.input_val_id,
shuffle=False, **kwargs)
return self.dataloader(
self.input_val_edges,
self.input_val_labels,
self.input_val_time,
self.input_val_id,
link_sampler=getattr(self, 'eval_neighbor_sampler', None),
shuffle=False,
**self.eval_loader_kwargs,
)

def test_dataloader(self) -> DataLoader:
""""""
kwargs = copy.copy(self.kwargs)
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)
return self.dataloader(
self.input_test_edges,
self.input_test_labels,
self.input_test_time,
self.input_test_id,
link_sampler=getattr(self, 'eval_neighbor_sampler', None),
shuffle=False,
**self.eval_loader_kwargs,
)

return self.dataloader(self.input_test_edges, self.input_test_labels,
self.input_test_time, self.input_test_id,
shuffle=False, **kwargs)
def predict_dataloader(self) -> DataLoader:
""""""
return self.dataloader(
self.input_pred_edges,
self.input_pred_labels,
self.input_pred_time,
self.input_pred_id,
link_sampler=getattr(self, 'eval_neighbor_sampler', None),
shuffle=False,
**self.eval_loader_kwargs,
)

def __repr__(self) -> str:
kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
Expand Down