Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
shutdown workers on failure (Lightning-AI#10463)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored and Raalsky committed Nov 23, 2021
1 parent 3bce41a commit dffa3a1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))


- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))


Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
# reset bookkeeping
self.state.stage = None
self.on_exception(exception)
# shutdown workers
self._data_connector.teardown()
raise

def fit(
Expand Down
34 changes: 27 additions & 7 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from pl_examples.bug_report_model import RandomDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.trainer.progress import BaseProgress
from tests.helpers import BoringModel
Expand Down Expand Up @@ -907,8 +907,10 @@ def val_dataloader(self):


@RunIf(min_torch="1.8.0")
@pytest.mark.parametrize("persistent_workers", (False, True))
def test_workers_are_shutdown(tmpdir, persistent_workers):
@pytest.mark.parametrize("should_fail", [False, True])
# False is de-activated due to slowness
@pytest.mark.parametrize("persistent_workers", [True])
def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
# `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
# `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

Expand Down Expand Up @@ -936,12 +938,30 @@ def _get_iterator(self):
train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

class TestCallback(Callback):
def on_train_epoch_end(self, trainer, *_):
if trainer.current_epoch == 1:
raise CustomException

max_epochs = 3

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
trainer.fit(model, train_dataloader, val_dataloader)
assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=max_epochs,
callbacks=TestCallback() if should_fail else None,
)

if should_fail:
with pytest.raises(CustomException):
trainer.fit(model, train_dataloader, val_dataloader)
else:
trainer.fit(model, train_dataloader, val_dataloader)

assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
# on sanity checking end, the workers are being deleted too.
assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
assert train_dataloader._iterator is None
assert val_dataloader._iterator is None

0 comments on commit dffa3a1

Please sign in to comment.