From aae797f71434a8e79f0b0c07e0db0e37597523f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 9 Nov 2021 16:35:47 +0100
Subject: [PATCH] [Fault Tolerance] Don't check the len of a dataset, but its
 instance. (#10432)

Co-authored-by: Thomas Chaton
---
 CHANGELOG.md                              | 1 +
 pytorch_lightning/trainer/data_loading.py | 7 ++++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22615f329dbfd..c96074a6f640f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))
 - Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
 - Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429))
+- Fixed an issue with inferring the dataset type in fault-tolerant training ([#10432](https://github.com/PyTorchLightning/pytorch-lightning/pull/10432))
 
 
 ## [1.5.0] - 2021-11-02

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index c41d80b903d4e..0f1eccc3f7cbb 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -37,7 +37,7 @@
     CaptureMapDataset,
     FastForwardSampler,
 )
-from pytorch_lightning.utilities.data import has_iterable_dataset, has_len_all_ranks
+from pytorch_lightning.utilities.data import get_len, has_iterable_dataset, has_len_all_ranks
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -282,10 +282,11 @@ def _get_dataloader_init_kwargs(
             dl_kwargs["sampler"] = None
 
         if _fault_tolerant_training():
-            if isinstance(dl_kwargs["dataset"], IterableDataset):
+            dataset = dl_kwargs["dataset"]
+            if isinstance(dataset, IterableDataset):
                 # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                 dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
-            elif len(dl_kwargs["dataset"]):
+            elif get_len(dataset) != float("inf"):
                 dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
             else:
                 raise MisconfigurationException(
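
Reviewer note (not part of the patch): a minimal sketch of the branch logic this change affects.
`approx_get_len` and `choose_capture_wrapper` below are hypothetical helpers written only for
illustration; `approx_get_len` merely approximates `pytorch_lightning.utilities.data.get_len`,
which is assumed to return float("inf") when the dataset exposes no usable `__len__`.

from typing import Union

from torch.utils.data import Dataset, IterableDataset


def approx_get_len(dataset) -> Union[int, float]:
    # Map-style datasets report a finite length; anything without `__len__`
    # (e.g. a bare `IterableDataset`) is treated as infinitely long.
    try:
        return len(dataset)
    except TypeError:
        return float("inf")


def choose_capture_wrapper(dataset) -> str:
    # Mirrors the patched branch in `_get_dataloader_init_kwargs`.
    if isinstance(dataset, IterableDataset):
        return "CaptureIterableDataset"
    if approx_get_len(dataset) != float("inf"):
        # An empty map-style dataset (len == 0) now takes this branch, whereas the
        # old `elif len(dl_kwargs["dataset"]):` truthiness check would have fallen
        # through to the error case below.
        return "CaptureMapDataset"
    return "MisconfigurationException"


class EmptyMapDataset(Dataset):
    # A map-style dataset with zero elements, used to show the behaviour change.
    def __len__(self) -> int:
        return 0

    def __getitem__(self, index):
        raise IndexError(index)


if __name__ == "__main__":
    print(choose_capture_wrapper(EmptyMapDataset()))  # -> CaptureMapDataset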