fix overfit_batch sampler replacement logic (Lightning-AI#10486)
Co-authored-by: thomas chaton <thomas@grid.ai>
2 people authored and Raalsky committed Nov 23, 2021
1 parent 56c9bf1 commit e772a4a
Showing 3 changed files with 68 additions and 19 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -154,8 +154,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))


+- Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
+
+
+- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
+
+
 -

 ## [1.5.1] - 2021-11-09

 ### Fixed
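In practice the new changelog entry means that `overfit_batches` now forces sequential sampling whenever the training dataloader is not already using a `SequentialSampler`, instead of only reacting to a `RandomSampler`. A minimal usage sketch, not part of this commit, reusing the repository's `BoringModel`/`RandomDataset` test helpers and assuming the 1.5.x wrapping of `trainer.train_dataloader` shown in the tests below (`ShuffledModel` is an illustrative name):

from torch.utils.data import DataLoader, SequentialSampler

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset


class ShuffledModel(BoringModel):
    def train_dataloader(self):
        # shuffle=True gives the loader a RandomSampler, which conflicts with overfitting
        return DataLoader(RandomDataset(32, 64), shuffle=True)


model = ShuffledModel()
trainer = Trainer(max_epochs=1, overfit_batches=4)
trainer.fit(model)  # warns about shuffling and swaps in a SequentialSampler

# the same first 4 batches are now repeated every epoch
assert trainer.num_training_batches == 4
assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)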
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
@@ -438,8 +438,7 @@ def _reset_eval_dataloader(
         for loader_i in range(len(dataloaders)):
             loader = dataloaders[loader_i]
 
-            if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
-
+            if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
                 # when overfitting, the dataloader should not have sampler
                 if self.overfit_batches > 0 and mode.evaluating:
                     rank_zero_warn(
@@ -591,16 +590,17 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:

     @staticmethod
     def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
-        has_random_sampler = False
+        all_have_sequential_sampler = True
 
-        def resolve_had_random_sampler(dataloader: DataLoader):
-            nonlocal has_random_sampler
-            if not has_random_sampler:
-                has_random_sampler = isinstance(dataloader.sampler, RandomSampler)
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+            nonlocal all_have_sequential_sampler
+            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
+                dataloader.sampler, SequentialSampler
+            )
 
-        apply_to_collection(dataloader, DataLoader, resolve_had_random_sampler)
+        apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)
 
-        if has_random_sampler:
+        if not all_have_sequential_sampler:
             rank_zero_warn(
                 "You requested to overfit but enabled training dataloader shuffling."
                 " We are turning off the training dataloader shuffling for you."
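The behavioural change is easiest to see in isolation: the old helper only looked for a `RandomSampler`, so a custom non-sequential sampler slipped through and `overfit_batches` did not pin the batches; the new helper treats anything other than a `SequentialSampler` as a reason to warn and replace. A self-contained sketch of that check follows; it mirrors, but is not, the Lightning helper (`has_non_sequential_sampler` is an illustrative name, and a plain loop stands in for `apply_to_collection`):

from typing import Collection

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset


def has_non_sequential_sampler(dataloaders: Collection[DataLoader]) -> bool:
    # True if any loader would need its sampler replaced before overfitting
    return any(not isinstance(dl.sampler, SequentialSampler) for dl in dataloaders)


dataset = TensorDataset(torch.randn(8, 2))
sequential = DataLoader(dataset)                                # SequentialSampler by default
shuffled = DataLoader(dataset, sampler=RandomSampler(dataset))  # the only case the old check caught

print(has_non_sequential_sampler([sequential]))             # False -> samplers left untouched
print(has_non_sequential_sampler([sequential, shuffled]))   # True  -> warn and force sequential sampling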
63 changes: 53 additions & 10 deletions tests/trainer/flags/test_overfit_batches.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 import pytest
 import torch
+from torch.utils.data.sampler import Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
 from tests.helpers.boring_model import BoringModel, RandomDataset
 
 
 def test_overfit_multiple_val_loaders(tmpdir):
-    """Tests that only training_step can be used."""
+    """Tests that overfit batches works with multiple val dataloaders."""
+    val_dl_count = 2
+    overfit_batches = 3
 
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx, dataloader_idx):
@@ -31,25 +34,65 @@ def validation_epoch_end(self, outputs) -> None:
             pass
 
         def val_dataloader(self):
-            dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            return [dl1, dl2]
+            dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)]
+            return dls
 
     model = TestModel()
 
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=2, overfit_batches=1, log_every_n_steps=1, enable_model_summary=False
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        overfit_batches=overfit_batches,
+        log_every_n_steps=1,
+        enable_model_summary=False,
     )
 
     trainer.fit(model)
+    assert trainer.num_training_batches == overfit_batches
+    assert len(trainer.num_val_batches) == val_dl_count
+    assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches)
 
 
-@pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0])
-def test_overfit_basic(tmpdir, overfit):
-    """Tests that only training_step can be used."""
+@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
+def test_overfit_basic(tmpdir, overfit_batches):
+    """Tests that only training_step can be used when overfitting."""
 
     model = BoringModel()
     model.validation_step = None
+    total_train_samples = len(BoringModel().train_dataloader())
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit, enable_model_summary=False)
-
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit_batches, enable_model_summary=False
+    )
     trainer.fit(model)
+
+    assert trainer.num_val_batches == []
+    assert trainer.num_training_batches == int(
+        overfit_batches * (1 if isinstance(overfit_batches, int) else total_train_samples)
+    )
+
+
+def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
+    class NonSequentialSampler(Sampler):
+        def __init__(self, data_source):
+            self.data_source = data_source
+
+        def __iter__(self):
+            return iter(range(len(self.data_source)))
+
+        def __len__(self):
+            return len(self.data_source)
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            dataset = RandomDataset(32, 64)
+            sampler = NonSequentialSampler(dataset)
+            return torch.utils.data.DataLoader(dataset, sampler=sampler)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
+
+    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+        trainer.fit(model)
+
+    assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
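Note on the final assertion in the new test: in this Lightning release the Trainer wraps the training dataloader(s) in a `CombinedLoader`, so the user's `DataLoader` (and hence the replaced sampler) is reached via `trainer.train_dataloader.loaders` rather than `trainer.train_dataloader` directly. A continuation sketch, assuming that wrapping:

# after `trainer.fit(model)` in the test above
combined = trainer.train_dataloader   # CombinedLoader wrapping the user's loader(s)
sampler = combined.loaders.sampler    # sampler of the underlying torch DataLoader
print(type(sampler).__name__)         # expected: "SequentialSampler" once overfitting replaced it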
