diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1e27ad5b18f62..f4fa30754b92f 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The selection `Trainer(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780)) +- Predict's custom BatchSampler that tracks the batch indices no longer consumes the entire batch sampler at the beginning ([#16826](https://github.com/Lightning-AI/lightning/pull/16826)) + + ### Deprecated - @@ -237,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664)) +- The `lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper` class is now marked as protected ([#16826](https://github.com/Lightning-AI/lightning/pull/16826)) + + - Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) @@ -362,6 +368,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826)) + + - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806)) @@ -373,8 +382,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719)) -- Fixed bug where `set_epoch` was not called for prediction dataloaders ([#16785](https://github.com/Lightning-AI/lightning/pull/16785)) - ## [1.9.1] - 2023-02-10 ### Fixed diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 2a3c1ef3dc35c..ebf4ca4076e43 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -11,7 +11,7 @@ from lightning.pytorch.loops.loop import _Loop from lightning.pytorch.loops.progress import Progress from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from lightning.pytorch.strategies import DDPSpawnStrategy from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource @@ -215,18 +215,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our - :class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`.""" + :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`.""" batch_sampler = getattr(dataloader, "batch_sampler", None) - if not isinstance(batch_sampler, IndexBatchSamplerWrapper): + if not isinstance(batch_sampler, _IndexBatchSamplerWrapper): self._warning_cache.warn( f"Couldn't infer the batch indices fetched from your dataloader: `{type(dataloader).__name__}`" ) return [] - seen_batch_indices = batch_sampler.seen_batch_indices - # TODO(carmocca): this could be avoided - # we need to truncate the list because `IndexBatchSamplerWrapper` computes all indices on `__iter__` - seen_batch_indices = seen_batch_indices[: (self.batch_progress.current.completed + 1)] - return seen_batch_indices + return batch_sampler.seen_batch_indices def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) -> bool: prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)] @@ -238,7 +234,7 @@ def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) dataloader = combined_loader.flattened[dataloader_idx] batch_indices = self._get_batch_indices(dataloader) if not batch_indices: - # this is only available with `IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is + # this is only available with `_IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is # reached, it's likely because a non-DataLoader was passed return any_on_epoch batch_indices = batch_indices[batch_idx] diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 0d6aa9182307e..f9c4602a10f62 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -124,16 +124,25 @@ def _reset_progress(loop: _Loop) -> None: def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None: - """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader. 
+ """Calls the ``set_epoch`` method on either the sampler of the given dataloader. - Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a + Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off. """ - for sampler_name in ("sampler", "batch_sampler"): - sampler = getattr(dataloader, sampler_name, None) - if sampler is not None and callable(getattr(sampler, "set_epoch", None)): - sampler.set_epoch(epoch) + objects = set() + # check dataloader.sampler + if (sampler := getattr(dataloader, "sampler", None)) is not None: + objects.add(sampler) + # check dataloader.batch_sampler.sampler + if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and ( + sampler := getattr(batch_sampler, "sampler", None) + ) is not None: + objects.add(sampler) + for obj in objects: + set_epoch = getattr(obj, "set_epoch", None) + if callable(set_epoch): + set_epoch(epoch) def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 8830c223c622a..b41ed71b29c2e 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, cast, Iterable, Iterator, List, Sized, Union +from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union import torch from torch import Tensor @@ -108,34 +108,36 @@ def __iter__(self) -> Iterator: return (self.dataset[index] for index in super().__iter__()) -class IndexBatchSamplerWrapper: +class _IndexBatchSamplerWrapper(BatchSampler): """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices.""" - def __init__(self, sampler: BatchSampler) -> None: + def __init__(self, batch_sampler: BatchSampler) -> None: + # do not call super().__init__() on purpose self.seen_batch_indices: List[List[int]] = [] - self._sampler = sampler + + self.__dict__ = { + k: v + for k, v in batch_sampler.__dict__.items() + if k not in ("__next__", "__iter__", "__len__", "__getstate__") + } + self._batch_sampler = batch_sampler + self._iterator: Optional[Iterator[List[int]]] = None + + def __next__(self) -> List[int]: + assert self._iterator is not None + batch = next(self._iterator) + self.seen_batch_indices.append(batch) + return batch def __iter__(self) -> Iterator[List[int]]: self.seen_batch_indices = [] - for batch in self._sampler: - self.seen_batch_indices.append(batch) - yield batch + self._iterator = iter(self._batch_sampler) + return self def __len__(self) -> int: - return len(self._sampler) - - @property - def drop_last(self) -> bool: - return self._sampler.drop_last - - @property - def batch_size(self) -> int: - return self._sampler.batch_size - - @property - def sampler(self) -> Union[Sampler, Iterable]: - return self._sampler.sampler + return len(self._batch_sampler) - def set_epoch(self, epoch: int) -> None: - if hasattr(self._sampler, "set_epoch"): - self._sampler.set_epoch(epoch) + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_iterator"] = None # cannot pickle 'generator' object + 
return state diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 0165535425659..2de1b55caaad2 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -29,7 +29,7 @@ has_iterable_dataset, sized_len, ) -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache @@ -246,7 +246,7 @@ def _dataloader_init_kwargs_resolve_sampler( """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-instantiation. - If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so + If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so Lightning can keep track of its indices. If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated @@ -322,8 +322,9 @@ def _dataloader_init_kwargs_resolve_sampler( ) from e if is_predicting: - batch_sampler = IndexBatchSamplerWrapper(batch_sampler) + batch_sampler = _IndexBatchSamplerWrapper(batch_sampler) + # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last return { "sampler": None, "shuffle": False, diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 5d189da6d17f5..15abe16131e57 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -43,18 +43,22 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): assert eval_epoch_end_mock.call_count == 4 -def test_evaluation_loop_sampler_set_epoch_called(tmpdir): +@pytest.mark.parametrize("use_batch_sampler", (False, True)) +def test_evaluation_loop_sampler_set_epoch_called(tmp_path, use_batch_sampler): """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation.""" def _get_dataloader(): dataset = RandomDataset(32, 64) sampler = RandomSampler(dataset) sampler.set_epoch = Mock() + if use_batch_sampler: + batch_sampler = BatchSampler(sampler, 2, True) + return DataLoader(dataset, batch_sampler=batch_sampler) return DataLoader(dataset, sampler=sampler) model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=tmp_path, limit_train_batches=1, limit_val_batches=1, max_epochs=2, @@ -66,48 +70,19 @@ def _get_dataloader(): train_dataloader = _get_dataloader() val_dataloader = _get_dataloader() trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) - # One for each epoch - assert train_dataloader.sampler.set_epoch.mock_calls == [call(0), call(1)] - # One for each epoch + sanity check - assert val_dataloader.sampler.set_epoch.mock_calls == [call(0), call(0), call(1)] - - val_dataloader = _get_dataloader() - trainer.validate(model, val_dataloader) - assert val_dataloader.sampler.set_epoch.mock_calls == [call(2)] - - -def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir): - """Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation.""" - - def _get_dataloader(): - dataset = RandomDataset(32, 
64) - sampler = RandomSampler(dataset) - batch_sampler = BatchSampler(sampler, 2, True) - batch_sampler.set_epoch = Mock() - return DataLoader(dataset, batch_sampler=batch_sampler) - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, - max_epochs=2, - enable_model_summary=False, - enable_checkpointing=False, - logger=False, - ) + train_sampler = train_dataloader.batch_sampler.sampler if use_batch_sampler else train_dataloader.sampler + val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler - train_dataloader = _get_dataloader() - val_dataloader = _get_dataloader() - trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) # One for each epoch - assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)] + assert train_sampler.set_epoch.mock_calls == [call(0), call(1)] # One for each epoch + sanity check - assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)] + assert val_sampler.set_epoch.mock_calls == [call(0), call(0), call(1)] val_dataloader = _get_dataloader() trainer.validate(model, val_dataloader) - assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)] + val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler + + assert val_sampler.set_epoch.mock_calls == [call(2)] @mock.patch( diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 324ea6f4366cb..bc6d5f209a1dd 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import itertools -from unittest import mock -from unittest.mock import call import pytest +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper def test_prediction_loop_stores_predictions(tmp_path): @@ -51,21 +51,39 @@ def predict_step(self, batch, batch_idx): assert trainer.predict_loop.predictions == [] -def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path): +@pytest.mark.parametrize("replace_sampler_ddp", (False, True)) +def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sampler_ddp): """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" - model = BoringModel() trainer = Trainer( default_root_dir=tmp_path, limit_predict_batches=1, enable_model_summary=False, enable_checkpointing=False, logger=False, + strategy="ddp", + devices=1, + accelerator="cpu", + replace_sampler_ddp=replace_sampler_ddp, ) - trainer.fit_loop.epoch_progress.current.processed = 2 - with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock: - trainer.predict(model) - assert set_epoch_mock.mock_calls == [call(2)] + class MyModel(BoringModel): + def predict_dataloader(self): + dataset = RandomDataset(32, 64) + sampler = None + if not replace_sampler_ddp: + sampler = DistributedSampler(dataset) + return DataLoader(dataset, sampler=sampler) + + model = MyModel() + trainer.fit_loop.epoch_progress.current.processed = 2 + trainer.predict(model) + + # torch will set this .sampler attribute for backwards compatibility, but in reality, the batch sampler is used + assert isinstance(trainer.predict_dataloaders.sampler, SequentialSampler) + batch_sampler = trainer.predict_dataloaders.batch_sampler + assert isinstance(batch_sampler, _IndexBatchSamplerWrapper) + assert isinstance(batch_sampler.sampler, DistributedSampler) + assert batch_sampler.sampler.epoch == 2 def test_prediction_loop_with_iterable_dataset(tmp_path): diff --git a/tests/tests_pytorch/loops/test_utilities.py b/tests/tests_pytorch/loops/test_utilities.py index 422daa2843706..0b873a169d83a 100644 --- a/tests/tests_pytorch/loops/test_utilities.py +++ b/tests/tests_pytorch/loops/test_utilities.py @@ -33,4 +33,4 @@ def test_set_sampler_epoch(): dataloader = Mock() _set_sampler_epoch(dataloader, 55) dataloader.sampler.set_epoch.assert_called_once_with(55) - dataloader.batch_sampler.set_epoch.assert_called_once_with(55) + dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55) diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 6e06325313afc..4a1057b1b0d78 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -18,11 +18,11 @@ from lightning.fabric.utilities.data import has_len from lightning.pytorch import seed_everything -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper, UnrepeatedDistributedSampler @pytest.mark.parametrize("shuffle", [False, True]) -def test_unrepeated_distributed_sampler(shuffle, tmpdir): +def test_unrepeated_distributed_sampler(shuffle): """Test each rank 
will receive a different number of elements.""" seed_everything(42) @@ -44,24 +44,29 @@ def test_unrepeated_distributed_sampler(shuffle, tmpdir): assert indices[3][-1] == 35 if shuffle else 99 -def test_index_batch_sampler(tmpdir): +def test_index_batch_sampler(): """Test `IndexBatchSampler` properly extracts indices.""" dataset = range(15) sampler = SequentialSampler(dataset) batch_sampler = BatchSampler(sampler, 3, False) - index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler) + index_batch_sampler = _IndexBatchSamplerWrapper(batch_sampler) + assert isinstance(index_batch_sampler, BatchSampler) assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler + assert index_batch_sampler.sampler is sampler assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices - - -def test_index_batch_sampler_methods(): - dataset = range(15) - sampler = SequentialSampler(dataset) - batch_sampler = BatchSampler(sampler, 3, False) - index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler) + assert list(index_batch_sampler) == list(batch_sampler) assert isinstance(index_batch_sampler, Iterable) assert has_len(index_batch_sampler) + + iterator = iter(index_batch_sampler) + assert index_batch_sampler.seen_batch_indices == [] + b0 = next(iterator) + assert b0 == [0, 1, 2] + assert index_batch_sampler.seen_batch_indices == [b0] + b1 = next(iterator) + assert b1 == [3, 4, 5] + assert index_batch_sampler.seen_batch_indices == [b0, b1] diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 553fbd49259c4..bf80f15b44c68 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -49,7 +49,7 @@ RandomIterableDatasetWithLen, ) from lightning.pytorch.loggers import TensorBoardLogger -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from lightning.pytorch.strategies import DDPSpawnStrategy, DDPStrategy, SingleDeviceStrategy from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -1286,7 +1286,7 @@ def on_predict_epoch_end(self, trainer, pl_module): if trainer._accelerator_connector.is_distributed: for idx in range(2): assert isinstance(trainer.predict_dataloaders[idx].batch_sampler.sampler, UnrepeatedDistributedSampler) - assert isinstance(trainer.predict_dataloaders[idx].batch_sampler, IndexBatchSamplerWrapper) + assert isinstance(trainer.predict_dataloaders[idx].batch_sampler, _IndexBatchSamplerWrapper) super().on_predict_epoch_end(trainer, pl_module) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 592b5a451bf6b..e8bfc87c2f485 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -10,7 +10,7 @@ from lightning.fabric.utilities.data import _replace_dunder_methods from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from lightning.pytorch.trainer.states import RunningStage from 
lightning.pytorch.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, @@ -176,8 +176,8 @@ def __init__(self, sampler, extra_arg, drop_last=True): batch_sampler = dataloader.batch_sampler if predicting: - assert isinstance(batch_sampler, IndexBatchSamplerWrapper) - batch_sampler = batch_sampler._sampler + assert isinstance(batch_sampler, _IndexBatchSamplerWrapper) + batch_sampler = batch_sampler._batch_sampler assert isinstance(batch_sampler, MyBatchSampler) assert batch_sampler.drop_last == (not predicting)
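
Below is a minimal, standalone sketch (not part of the patch) of the two behaviors this diff introduces: `_IndexBatchSamplerWrapper` now records batch indices lazily as batches are consumed instead of exhausting the wrapped batch sampler on `__iter__`, and `_set_sampler_epoch` also reaches the sampler nested inside `dataloader.batch_sampler`. It assumes a Lightning checkout that already contains these changes; `_EpochRecordingSampler` is a hypothetical stand-in for a sampler exposing a `set_epoch` hook, such as `DistributedSampler`.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# Protected helpers introduced/renamed in this diff (import paths taken from the patch)
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper

dataset = list(range(15))

# 1) Lazy index tracking: nothing is consumed up front, and `seen_batch_indices`
#    only contains the batches produced so far.
wrapped = _IndexBatchSamplerWrapper(BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False))
iterator = iter(wrapped)
assert wrapped.seen_batch_indices == []
first = next(iterator)   # [0, 1, 2]
second = next(iterator)  # [3, 4, 5]
assert wrapped.seen_batch_indices == [first, second]


# 2) Epoch propagation: `_set_sampler_epoch` now also calls `set_epoch` on
#    `dataloader.batch_sampler.sampler`, not just on `dataloader.sampler`.
class _EpochRecordingSampler(SequentialSampler):
    """Hypothetical sampler with a `set_epoch` hook (stand-in for DistributedSampler)."""

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


sampler = _EpochRecordingSampler(dataset)
loader = DataLoader(dataset, batch_sampler=BatchSampler(sampler, batch_size=3, drop_last=False))
_set_sampler_epoch(loader, epoch=2)
assert loader.batch_sampler.sampler.epoch == 2
```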