From c8367a22b883d07d9a3b60034b5787c984518779 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 1 Feb 2023 20:00:46 +0100
Subject: [PATCH] Sized iterable typing improvements (#16585)

---
 src/lightning/fabric/utilities/data.py   | 33 ++++---
 src/lightning/pytorch/loops/utilities.py |  7 +-
 src/lightning/pytorch/trainer/trainer.py |  2 +-
 .../pytorch/tuner/batch_size_scaling.py  |  6 +-
 src/lightning/pytorch/utilities/data.py  | 93 +++++++++----
 5 files changed, 67 insertions(+), 74 deletions(-)

diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py
index cb44a9bab4b40..f8a3ae9a15d22 100644
--- a/src/lightning/fabric/utilities/data.py
+++ b/src/lightning/fabric/utilities/data.py
@@ -18,10 +18,11 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union
 
 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from typing_extensions import TypeGuard
 
 from lightning.fabric.utilities.enums import LightningEnum
 from lightning.fabric.utilities.exceptions import MisconfigurationException
@@ -42,31 +43,35 @@ def __call__(self, *args: Any) -> None:
         return fn(*args)
 
 
-def has_iterable_dataset(dataloader: DataLoader) -> bool:
+def has_iterable_dataset(dataloader: object) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
 
 
-def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
-    """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader."""
+def sized_len(dataloader: object) -> Optional[int]:
+    """Try to get the length of an object, return ``None`` otherwise."""
     try:
         # try getting the length
-        if len(dataloader) == 0:  # type: ignore [arg-type]
-            rank_zero_warn(
-                f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
-            )
-        has_len = True
+        length = len(dataloader)  # type: ignore [arg-type]
     except (TypeError, NotImplementedError):
-        has_len = False
+        length = None
+    return length
+
 
-    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
+def has_len(dataloader: object) -> TypeGuard[Sized]:
+    """Checks if a given object has ``__len__`` method implemented."""
+    length = sized_len(dataloader)
+    if length == 0:
+        rank_zero_warn(
+            f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
+        )
+    if length is not None and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
             " `__len__` could be inaccurate if each worker is not configured independently"
             " to avoid having duplicate data."
         )
-    return has_len
+    return length is not None
 
 
 def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py
index d2a7c863e7d82..5880bade35577 100644
--- a/src/lightning/pytorch/loops/utilities.py
+++ b/src/lightning/pytorch/loops/utilities.py
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Generator, Optional, Tuple, Union
+from typing import Generator, Iterable, Optional, Tuple
 
 import torch
 from torch import Tensor
-from torch.utils.data import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -25,7 +24,6 @@
 from lightning.pytorch.loops.progress import BaseProgress
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import Strategy
-from lightning.pytorch.trainer.supporters import CombinedLoader
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
 
@@ -121,14 +119,13 @@ def _reset_progress(loop: _Loop) -> None:
             _reset_progress(v)
 
 
-def _set_sampler_epoch(dataloader: Union[DataLoader, CombinedLoader], epoch: int) -> None:
+def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
     """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
 
     Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
     :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning of
     every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
""" - for sampler_name in ("sampler", "batch_sampler"): sampler = getattr(dataloader, sampler_name, None) if sampler is not None and callable(getattr(sampler, "set_epoch", None)): diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 58146fcd25531..fdf814d69b2d0 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1361,7 +1361,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - module = model or self.lightning_module or self.datamodule orig_train_batches = self.num_training_batches = ( - len(self.train_dataloader) # type: ignore[arg-type] + len(self.train_dataloader) if has_len_all_ranks(self.train_dataloader, self.strategy, module) else float("inf") ) diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index e8d221dab0268..6f206209cd4d3 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os import uuid from copy import deepcopy -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -328,9 +328,7 @@ def _adjust_batch_size( return new_size, changed -def _is_valid_batch_size( - batch_size: int, dataloader: "pl.trainer.supporters.CombinedLoader", trainer: "pl.Trainer" -) -> bool: +def _is_valid_batch_size(batch_size: int, dataloader: Iterable, trainer: "pl.Trainer") -> bool: from lightning.pytorch.utilities.data import has_len_all_ranks module = trainer.lightning_module or trainer.datamodule diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 9ae5cf150318d..0b7ec1ef6153d 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -13,28 +13,24 @@ # limitations under the License. 
 import inspect
 from dataclasses import fields
-from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union
 
 import torch
 from lightning_utilities.core.apply_func import is_dataclass_instance
+from lightning_utilities.core.rank_zero import rank_prefixed_message
 from torch import Tensor
-from torch.utils.data import (
-    BatchSampler,
-    DataLoader,
-    Dataset,
-    IterableDataset,
-    RandomSampler,
-    Sampler,
-    SequentialSampler,
-)
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
+from typing_extensions import TypeGuard
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
-from lightning.fabric.utilities.data import has_iterable_dataset as new_has_iterable_dataset
-from lightning.fabric.utilities.data import has_len as new_has_len
+from lightning.fabric.utilities.data import (
+    _reinstantiate_wrapped_cls,
+    _replace_value_in_saved_args,
+    has_iterable_dataset,
+    sized_len,
+)
 from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 
@@ -94,41 +90,41 @@ def extract_batch_size(batch: BType) -> int:
 
 
 def has_len_all_ranks(
-    dataloader: Union[DataLoader, CombinedLoader],
+    dataloader: object,
     strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
-) -> bool:
-    """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader."""
-    try:
-        local_length = len(dataloader)  # type: ignore [arg-type] # we are checking with duck-typing
-        total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
-
-        if total_length == 0:
-            rank_zero_warn(
-                f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
-                " Please make sure this was your intention."
+) -> TypeGuard[Sized]:
+    """Checks if a given object has ``__len__`` method implemented on all ranks."""
+    local_length = sized_len(dataloader)
+    has_len = True
+    if local_length is None:
+        # if one rank does not define a length, the reduction after would fail, default to 0
+        local_length = 0
+        has_len = False
+    total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
+    if total_length == 0:
+        rank_zero_warn(
+            f"Total length of `{type(dataloader).__name__}` across ranks is zero."
+            " Please make sure this was your intention."
+        )
+    if total_length > 0 and local_length == 0:
+        dataloader_cls_name = type(dataloader).__name__
+        if not has_len:
+            raise RuntimeError(
+                rank_prefixed_message(f"The `{dataloader_cls_name}` does not define a length.", strategy.global_rank)
             )
-        if total_length > 0 and local_length == 0:
-            if model.allow_zero_length_dataloader_with_multiple_devices:
-                rank_zero_warn(
-                    f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
-                    " length. Please be cautious of uneven batch length."
-                )
-                has_len = False
-            else:
-                raise MisconfigurationException(
-                    f"`{dataloader.__class__.__name__}` within local rank has zero length."
- " Please make sure that it returns at least 1 batch." - ) - else: - has_len = True - - except (TypeError, NotImplementedError): + if not model.allow_zero_length_dataloader_with_multiple_devices: + raise RuntimeError( + f"`{dataloader_cls_name}` within local rank has zero length." + " Please make sure that it returns at least 1 batch." + ) + rank_zero_warn( + f"Total length of `{dataloader_cls_name}` across ranks is zero, but local rank has zero" + " length. Please be cautious of uneven batch length." + ) has_len = False - # we are checking using lightning.fabric, which doesn't know CombinedLoader - if has_len and new_has_iterable_dataset(dataloader): # type: ignore [arg-type] + if has_len and has_iterable_dataset(dataloader): rank_zero_warn( "Your `IterableDataset` has `__len__` defined." " In combination with multi-process data loading (when num_workers > 1)," @@ -138,16 +134,13 @@ def has_len_all_ranks( return has_len -def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]: +def get_len(dataloader: object) -> Union[int, float]: """Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf'). """ - - if new_has_len(dataloader): - return len(dataloader) # type: ignore [arg-type] - - return float("inf") + length = sized_len(dataloader) + return float("inf") if length is None else length def _update_dataloader(