Skip to content

Commit

Permalink
Consume the prediction batch indices iteratively (#16826)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Feb 22, 2023
1 parent 598c247 commit 62e3d58
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 107 deletions.
11 changes: 9 additions & 2 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The selection `Trainer(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))


- Predict's custom BatchSampler that tracks the batch indices no longer consumes the entire batch sampler at the beginning ([#16826](https://github.com/Lightning-AI/lightning/pull/16826))


### Deprecated

-
Expand Down Expand Up @@ -237,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))


- The `lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper` class is now marked as protected ([#16826](https://github.com/Lightning-AI/lightning/pull/16826))


- Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


Expand Down Expand Up @@ -362,6 +368,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))


- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))


Expand All @@ -373,8 +382,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))


- Fixed bug where `set_epoch` was not called for prediction dataloaders ([#16785](https://github.com/Lightning-AI/lightning/pull/16785))

## [1.9.1] - 2023-02-10

### Fixed
Expand Down
14 changes: 5 additions & 9 deletions src/lightning/pytorch/loops/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
Expand Down Expand Up @@ -215,18 +215,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int

def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples
"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
:class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`."""
:class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`."""
batch_sampler = getattr(dataloader, "batch_sampler", None)
if not isinstance(batch_sampler, IndexBatchSamplerWrapper):
if not isinstance(batch_sampler, _IndexBatchSamplerWrapper):
self._warning_cache.warn(
f"Couldn't infer the batch indices fetched from your dataloader: `{type(dataloader).__name__}`"
)
return []
seen_batch_indices = batch_sampler.seen_batch_indices
# TODO(carmocca): this could be avoided
# we need to truncate the list because `IndexBatchSamplerWrapper` computes all indices on `__iter__`
seen_batch_indices = seen_batch_indices[: (self.batch_progress.current.completed + 1)]
return seen_batch_indices
return batch_sampler.seen_batch_indices

def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) -> bool:
prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)]
Expand All @@ -238,7 +234,7 @@ def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int)
dataloader = combined_loader.flattened[dataloader_idx]
batch_indices = self._get_batch_indices(dataloader)
if not batch_indices:
# this is only available with `IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is
# this is only available with `_IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is
# reached, it's likely because a non-DataLoader was passed
return any_on_epoch
batch_indices = batch_indices[batch_idx]
Expand Down
21 changes: 15 additions & 6 deletions src/lightning/pytorch/loops/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,25 @@ def _reset_progress(loop: _Loop) -> None:


def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
for sampler_name in ("sampler", "batch_sampler"):
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
sampler.set_epoch(epoch)
objects = set()
# check dataloader.sampler
if (sampler := getattr(dataloader, "sampler", None)) is not None:
objects.add(sampler)
# check dataloader.batch_sampler.sampler
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
sampler := getattr(batch_sampler, "sampler", None)
) is not None:
objects.add(sampler)
for obj in objects:
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)


def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher:
Expand Down
48 changes: 25 additions & 23 deletions src/lightning/pytorch/overrides/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, cast, Iterable, Iterator, List, Sized, Union
from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -108,34 +108,36 @@ def __iter__(self) -> Iterator:
return (self.dataset[index] for index in super().__iter__())


class IndexBatchSamplerWrapper:
class _IndexBatchSamplerWrapper(BatchSampler):
"""This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

def __init__(self, sampler: BatchSampler) -> None:
def __init__(self, batch_sampler: BatchSampler) -> None:
# do not call super().__init__() on purpose
self.seen_batch_indices: List[List[int]] = []
self._sampler = sampler

self.__dict__ = {
k: v
for k, v in batch_sampler.__dict__.items()
if k not in ("__next__", "__iter__", "__len__", "__getstate__")
}
self._batch_sampler = batch_sampler
self._iterator: Optional[Iterator[List[int]]] = None

def __next__(self) -> List[int]:
assert self._iterator is not None
batch = next(self._iterator)
self.seen_batch_indices.append(batch)
return batch

def __iter__(self) -> Iterator[List[int]]:
self.seen_batch_indices = []
for batch in self._sampler:
self.seen_batch_indices.append(batch)
yield batch
self._iterator = iter(self._batch_sampler)
return self

def __len__(self) -> int:
return len(self._sampler)

@property
def drop_last(self) -> bool:
return self._sampler.drop_last

@property
def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Union[Sampler, Iterable]:
return self._sampler.sampler
return len(self._batch_sampler)

def set_epoch(self, epoch: int) -> None:
if hasattr(self._sampler, "set_epoch"):
self._sampler.set_epoch(epoch)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_iterator"] = None # cannot pickle 'generator' object
return state
7 changes: 4 additions & 3 deletions src/lightning/pytorch/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
has_iterable_dataset,
sized_len,
)
from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
Expand Down Expand Up @@ -246,7 +246,7 @@ def _dataloader_init_kwargs_resolve_sampler(
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
re-instantiation.
If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so
Lightning can keep track of its indices.
If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
Expand Down Expand Up @@ -322,8 +322,9 @@ def _dataloader_init_kwargs_resolve_sampler(
) from e

if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)

# batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
return {
"sampler": None,
"shuffle": False,
Expand Down
51 changes: 13 additions & 38 deletions tests/tests_pytorch/loops/test_evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,22 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
assert eval_epoch_end_mock.call_count == 4


def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
@pytest.mark.parametrize("use_batch_sampler", (False, True))
def test_evaluation_loop_sampler_set_epoch_called(tmp_path, use_batch_sampler):
"""Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
sampler = RandomSampler(dataset)
sampler.set_epoch = Mock()
if use_batch_sampler:
batch_sampler = BatchSampler(sampler, 2, True)
return DataLoader(dataset, batch_sampler=batch_sampler)
return DataLoader(dataset, sampler=sampler)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=tmp_path,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
Expand All @@ -66,48 +70,19 @@ def _get_dataloader():
train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
assert train_dataloader.sampler.set_epoch.mock_calls == [call(0), call(1)]
# One for each epoch + sanity check
assert val_dataloader.sampler.set_epoch.mock_calls == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
assert val_dataloader.sampler.set_epoch.mock_calls == [call(2)]


def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, 2, True)
batch_sampler.set_epoch = Mock()
return DataLoader(dataset, batch_sampler=batch_sampler)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
)
train_sampler = train_dataloader.batch_sampler.sampler if use_batch_sampler else train_dataloader.sampler
val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler

train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
assert train_sampler.set_epoch.mock_calls == [call(0), call(1)]
# One for each epoch + sanity check
assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
assert val_sampler.set_epoch.mock_calls == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]
val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler

assert val_sampler.set_epoch.mock_calls == [call(2)]


@mock.patch(
Expand Down
36 changes: 27 additions & 9 deletions tests/tests_pytorch/loops/test_prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from unittest import mock
from unittest.mock import call

import pytest
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper


def test_prediction_loop_stores_predictions(tmp_path):
Expand Down Expand Up @@ -51,21 +51,39 @@ def predict_step(self, batch, batch_idx):
assert trainer.predict_loop.predictions == []


def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path):
@pytest.mark.parametrize("replace_sampler_ddp", (False, True))
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sampler_ddp):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmp_path,
limit_predict_batches=1,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
strategy="ddp",
devices=1,
accelerator="cpu",
replace_sampler_ddp=replace_sampler_ddp,
)
trainer.fit_loop.epoch_progress.current.processed = 2

with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock:
trainer.predict(model)
assert set_epoch_mock.mock_calls == [call(2)]
class MyModel(BoringModel):
def predict_dataloader(self):
dataset = RandomDataset(32, 64)
sampler = None
if not replace_sampler_ddp:
sampler = DistributedSampler(dataset)
return DataLoader(dataset, sampler=sampler)

model = MyModel()
trainer.fit_loop.epoch_progress.current.processed = 2
trainer.predict(model)

# torch will set this .sampler attribute for backwards compatibility, but in reality, the batch sampler is used
assert isinstance(trainer.predict_dataloaders.sampler, SequentialSampler)
batch_sampler = trainer.predict_dataloaders.batch_sampler
assert isinstance(batch_sampler, _IndexBatchSamplerWrapper)
assert isinstance(batch_sampler.sampler, DistributedSampler)
assert batch_sampler.sampler.epoch == 2


def test_prediction_loop_with_iterable_dataset(tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def test_set_sampler_epoch():
dataloader = Mock()
_set_sampler_epoch(dataloader, 55)
dataloader.sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
Loading

0 comments on commit 62e3d58

Please sign in to comment.