From e772a4a5cab1f49ff8bd405df05e89ebc12e07f4 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 16 Nov 2021 04:01:45 +0530
Subject: [PATCH] fix overfit_batches sampler replacement logic (#10486)

Co-authored-by: thomas chaton
---
 CHANGELOG.md                                |  6 ++
 pytorch_lightning/trainer/data_loading.py   | 18 +++---
 tests/trainer/flags/test_overfit_batches.py | 63 +++++++++++++++++----
 3 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45d370188317b2..320c9ceefac962 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -154,8 +154,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
 
+- Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
+
+
 - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
 
+
+-
+
 ## [1.5.1] - 2021-11-09
 
 ### Fixed

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 931f6a92958ee4..bdc051091b50c7 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -438,8 +438,7 @@ def _reset_eval_dataloader(
         for loader_i in range(len(dataloaders)):
             loader = dataloaders[loader_i]
 
-            if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
-
+            if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
                 # when overfitting, the dataloader should not have a sampler
                 if self.overfit_batches > 0 and mode.evaluating:
                     rank_zero_warn(
@@ -591,16 +590,17 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
 
     @staticmethod
     def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
-        has_random_sampler = False
+        all_have_sequential_sampler = True
 
-        def resolve_had_random_sampler(dataloader: DataLoader):
-            nonlocal has_random_sampler
-            if not has_random_sampler:
-                has_random_sampler = isinstance(dataloader.sampler, RandomSampler)
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+            nonlocal all_have_sequential_sampler
+            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
+                dataloader.sampler, SequentialSampler
+            )
 
-        apply_to_collection(dataloader, DataLoader, resolve_had_random_sampler)
+        apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)
 
-        if has_random_sampler:
+        if not all_have_sequential_sampler:
             rank_zero_warn(
                 "You requested to overfit but enabled training dataloader shuffling."
                 " We are turning off the training dataloader shuffling for you."

diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py
index 76c8b37405b47e..3860d85ec9836d 100644
--- a/tests/trainer/flags/test_overfit_batches.py
+++ b/tests/trainer/flags/test_overfit_batches.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 import pytest
 import torch
+from torch.utils.data.sampler import Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
 from tests.helpers.boring_model import BoringModel, RandomDataset
 
 
 def test_overfit_multiple_val_loaders(tmpdir):
-    """Tests that only training_step can be used."""
+    """Tests that `overfit_batches` works with multiple val dataloaders."""
+    val_dl_count = 2
+    overfit_batches = 3
 
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx, dataloader_idx):
@@ -31,25 +34,65 @@ def validation_epoch_end(self, outputs) -> None:
             pass
 
         def val_dataloader(self):
-            dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            return [dl1, dl2]
+            dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)]
+            return dls
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=2, overfit_batches=1, log_every_n_steps=1, enable_model_summary=False
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        overfit_batches=overfit_batches,
+        log_every_n_steps=1,
+        enable_model_summary=False,
     )
     trainer.fit(model)
 
+    assert trainer.num_training_batches == overfit_batches
+    assert len(trainer.num_val_batches) == val_dl_count
+    assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches)
 
-@pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0])
-def test_overfit_basic(tmpdir, overfit):
-    """Tests that only training_step can be used."""
+
+@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
+def test_overfit_basic(tmpdir, overfit_batches):
+    """Tests that only training_step can be used when overfitting."""
 
     model = BoringModel()
+    model.validation_step = None
+    total_train_samples = len(BoringModel().train_dataloader())
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit, enable_model_summary=False)
-
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit_batches, enable_model_summary=False
+    )
     trainer.fit(model)
+
+    assert trainer.num_val_batches == []
+    assert trainer.num_training_batches == int(
+        overfit_batches * (1 if isinstance(overfit_batches, int) else total_train_samples)
+    )
+
+
+def test_overfit_batches_raises_warning_in_case_of_non_sequential_sampler(tmpdir):
+    class NonSequentialSampler(Sampler):
+        def __init__(self, data_source):
+            self.data_source = data_source
+
+        def __iter__(self):
+            return iter(range(len(self.data_source)))
+
+        def __len__(self):
+            return len(self.data_source)
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            dataset = RandomDataset(32, 64)
+            sampler = NonSequentialSampler(dataset)
+            return torch.utils.data.DataLoader(dataset, sampler=sampler)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
+
+    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+        trainer.fit(model)
+
+    assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
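
For illustration, a minimal standalone sketch (not Lightning code) of why the isinstance check in data_loading.py changed: the old test only matched RandomSampler, so a custom non-sequential sampler slipped through unreplaced, while testing against SequentialSampler flags anything that does not iterate in order. EveryOtherSampler below is a hypothetical sampler used only for this demonstration.

    import torch
    from torch.utils.data import RandomSampler, Sampler, SequentialSampler, TensorDataset


    class EveryOtherSampler(Sampler):
        """Hypothetical non-sequential sampler, for illustration only."""

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            # visits every other index: 0, 2, 4, ... -> not sequential
            return iter(range(0, len(self.data_source), 2))

        def __len__(self):
            return (len(self.data_source) + 1) // 2


    dataset = TensorDataset(torch.arange(8))
    sampler = EveryOtherSampler(dataset)

    old_check = isinstance(sampler, RandomSampler)          # False -> old logic never replaced it
    new_check = not isinstance(sampler, SequentialSampler)  # True  -> new logic replaces it
    assert not old_check and new_check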
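
And a hedged sketch of the overall resolve flow that the patched _resolve_overfit_batches implements, under the assumption that "replacing" a sampler amounts to rebuilding the DataLoader with shuffle=False; resolve_overfit_loaders is a hypothetical helper written for this sketch, not the Lightning API.

    import warnings

    import torch
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset


    def resolve_overfit_loaders(loaders):
        # Warn and replace only if at least one loader is not already sequential,
        # mirroring the "all_have_sequential_sampler" check in the patch above.
        if all(isinstance(dl.sampler, SequentialSampler) for dl in loaders):
            return loaders
        warnings.warn(
            "You requested to overfit but enabled training dataloader shuffling."
            " We are turning off the training dataloader shuffling for you."
        )
        # Rebuild each loader so it iterates the dataset in order (SequentialSampler).
        return [DataLoader(dl.dataset, batch_size=dl.batch_size, shuffle=False) for dl in loaders]


    dataset = TensorDataset(torch.arange(8))
    shuffled = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))
    ordered = DataLoader(dataset, batch_size=2)  # defaults to SequentialSampler

    resolved = resolve_overfit_loaders([shuffled, ordered])
    assert all(isinstance(dl.sampler, SequentialSampler) for dl in resolved)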