Sized iterable typing improvements (#16585)
carmocca authored Feb 1, 2023
1 parent 7d4780a commit c8367a2
Showing 5 changed files with 67 additions and 74 deletions.
33 changes: 19 additions & 14 deletions src/lightning/fabric/utilities/data.py
@@ -18,10 +18,11 @@
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union

from lightning_utilities.core.inheritance import get_all_subclasses
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
from typing_extensions import TypeGuard

from lightning.fabric.utilities.enums import LightningEnum
from lightning.fabric.utilities.exceptions import MisconfigurationException
@@ -42,31 +43,35 @@ def __call__(self, *args: Any) -> None:
return fn(*args)


def has_iterable_dataset(dataloader: DataLoader) -> bool:
def has_iterable_dataset(dataloader: object) -> bool:
return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader."""
def sized_len(dataloader: object) -> Optional[int]:
"""Try to get the length of an object, return ``None`` otherwise."""
try:
# try getting the length
if len(dataloader) == 0: # type: ignore [arg-type]
rank_zero_warn(
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
)
has_len = True
length = len(dataloader) # type: ignore [arg-type]
except (TypeError, NotImplementedError):
has_len = False
length = None
return length


if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
def has_len(dataloader: object) -> TypeGuard[Sized]:
"""Checks if a given object has ``__len__`` method implemented."""
length = sized_len(dataloader)
if length == 0:
rank_zero_warn(
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
)
if length is not None and has_iterable_dataset(dataloader):
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"
" `__len__` could be inaccurate if each worker is not configured independently"
" to avoid having duplicate data."
)
return has_len
return length is not None


def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
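Note: a minimal sketch of how the two new helpers behave, assuming the import path `lightning.fabric.utilities.data` shown in the diff; the `DataLoader` setup is only illustrative.

```python
# sized_len returns the length when len() works and None when it raises,
# while has_len only reports whether a length exists at all.
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric.utilities.data import has_len, sized_len

loader = DataLoader(TensorDataset(torch.zeros(10, 3)), batch_size=2)
stream = iter(range(10))  # a plain iterator does not implement __len__

assert sized_len(loader) == 5     # len() succeeds, the value is returned
assert sized_len(stream) is None  # len() raises TypeError, so None
assert has_len(loader) and not has_len(stream)
```

Because `has_len` is annotated with `TypeGuard[Sized]`, static checkers narrow the argument to `Sized` in the branch where it returns `True`, which is what allows several `type: ignore [arg-type]` comments elsewhere in this commit to be dropped.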
7 changes: 2 additions & 5 deletions src/lightning/pytorch/loops/utilities.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator, Optional, Tuple, Union
from typing import Generator, Iterable, Optional, Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader

import lightning.pytorch as pl
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -25,7 +24,6 @@
from lightning.pytorch.loops.progress import BaseProgress
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.trainer.supporters import CombinedLoader
from lightning.pytorch.utilities.rank_zero import rank_zero_warn


@@ -121,14 +119,13 @@ def _reset_progress(loop: _Loop) -> None:
_reset_progress(v)


def _set_sampler_epoch(dataloader: Union[DataLoader, CombinedLoader], epoch: int) -> None:
def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""

for sampler_name in ("sampler", "batch_sampler"):
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
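Note: since `_set_sampler_epoch` now accepts any `Iterable` and resolves the sampler by duck typing, the equivalent lookup can be sketched standalone as below. This is an approximation of the helper, not the Lightning entry point itself; the `DistributedSampler` is given explicit `num_replicas`/`rank` so no process group is required.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler)

for epoch in range(2):
    # same duck-typed lookup as _set_sampler_epoch(loader, epoch):
    # find a `sampler` or `batch_sampler` attribute and call its
    # set_epoch method if one is defined
    for name in ("sampler", "batch_sampler"):
        candidate = getattr(loader, name, None)
        if candidate is not None and callable(getattr(candidate, "set_epoch", None)):
            candidate.set_epoch(epoch)
    _ = list(loader)  # the shuffled order now depends on the epoch
```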
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -1361,7 +1361,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

module = model or self.lightning_module or self.datamodule
orig_train_batches = self.num_training_batches = (
len(self.train_dataloader) # type: ignore[arg-type]
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)
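Note: the dropped `type: ignore[arg-type]` follows from the `TypeGuard[Sized]` annotation: in the `True` branch of the conditional, type checkers treat the dataloader as `Sized`, so `len()` type-checks. A simplified, self-contained sketch of the pattern (the `has_len` stand-in below uses `isinstance` instead of the real helper's `len()` probe):

```python
from typing import Sized, Union

from typing_extensions import TypeGuard


def has_len(obj: object) -> TypeGuard[Sized]:
    """Simplified stand-in for lightning.fabric.utilities.data.has_len."""
    return isinstance(obj, Sized)  # Sized is a runtime-checkable ABC


def num_batches(dataloader: object) -> Union[int, float]:
    # mypy narrows `dataloader` to Sized in the True branch, so len() needs no ignore
    return len(dataloader) if has_len(dataloader) else float("inf")
```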
6 changes: 2 additions & 4 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -15,7 +15,7 @@
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple

import lightning.pytorch as pl
from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -328,9 +328,7 @@ def _adjust_batch_size(
return new_size, changed


def _is_valid_batch_size(
batch_size: int, dataloader: "pl.trainer.supporters.CombinedLoader", trainer: "pl.Trainer"
) -> bool:
def _is_valid_batch_size(batch_size: int, dataloader: Iterable, trainer: "pl.Trainer") -> bool:
from lightning.pytorch.utilities.data import has_len_all_ranks

module = trainer.lightning_module or trainer.datamodule
93 changes: 43 additions & 50 deletions src/lightning/pytorch/utilities/data.py
@@ -13,28 +13,24 @@
# limitations under the License.
import inspect
from dataclasses import fields
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
from lightning_utilities.core.rank_zero import rank_prefixed_message
from torch import Tensor
from torch.utils.data import (
BatchSampler,
DataLoader,
Dataset,
IterableDataset,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
from typing_extensions import TypeGuard

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
from lightning.fabric.utilities.data import has_iterable_dataset as new_has_iterable_dataset
from lightning.fabric.utilities.data import has_len as new_has_len
from lightning.fabric.utilities.data import (
_reinstantiate_wrapped_cls,
_replace_value_in_saved_args,
has_iterable_dataset,
sized_len,
)
from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.trainer.supporters import CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache

@@ -94,41 +90,41 @@ def extract_batch_size(batch: BType) -> int:


def has_len_all_ranks(
dataloader: Union[DataLoader, CombinedLoader],
dataloader: object,
strategy: "pl.strategies.Strategy",
model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader."""
try:
local_length = len(dataloader) # type: ignore [arg-type] # we are checking with duck-typing
total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")

if total_length == 0:
rank_zero_warn(
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
" Please make sure this was your intention."
) -> TypeGuard[Sized]:
"""Checks if a given object has ``__len__`` method implemented on all aranks."""
local_length = sized_len(dataloader)
has_len = True
if local_length is None:
# if one rank does not define a length, the reduction after would fail, default to 0
local_length = 0
has_len = False
total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
if total_length == 0:
rank_zero_warn(
f"Total length of `{type(dataloader).__name__}` across ranks is zero."
" Please make sure this was your intention."
)
if total_length > 0 and local_length == 0:
dataloader_cls_name = type(dataloader).__name__
if not has_len:
raise RuntimeError(
rank_prefixed_message(f"The `{dataloader_cls_name}` does not define a length.", strategy.global_rank)
)
if total_length > 0 and local_length == 0:
if model.allow_zero_length_dataloader_with_multiple_devices:
rank_zero_warn(
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
" length. Please be cautious of uneven batch length."
)
has_len = False
else:
raise MisconfigurationException(
f"`{dataloader.__class__.__name__}` within local rank has zero length."
" Please make sure that it returns at least 1 batch."
)
else:
has_len = True

except (TypeError, NotImplementedError):
if not model.allow_zero_length_dataloader_with_multiple_devices:
raise RuntimeError(
f"`{dataloader_cls_name}` within local rank has zero length."
" Please make sure that it returns at least 1 batch."
)
rank_zero_warn(
f"Total length of `{dataloader_cls_name}` across ranks is zero, but local rank has zero"
" length. Please be cautious of uneven batch length."
)
has_len = False

# we are checking using lightning.fabric, which doesn't know CombinedLoader
if has_len and new_has_iterable_dataset(dataloader): # type: ignore [arg-type]
if has_len and has_iterable_dataset(dataloader):
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"
@@ -138,16 +134,13 @@ def has_len_all_ranks(
return has_len
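Note: a hedged sketch of the reworked control flow above, driven through hypothetical single-process stand-ins for the strategy and module arguments; `_StubStrategy` and `_StubModule` are invented for illustration and are not Lightning APIs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch.utilities.data import has_len_all_ranks


class _StubStrategy:
    root_device = torch.device("cpu")
    global_rank = 0

    def reduce(self, tensor, reduce_op="sum"):
        # single process: the "sum" across ranks is just the local value
        return tensor


class _StubModule:
    allow_zero_length_dataloader_with_multiple_devices = False


loader = DataLoader(TensorDataset(torch.zeros(6, 1)), batch_size=3)
assert has_len_all_ranks(loader, _StubStrategy(), _StubModule())

# An object without __len__ now yields False (plus a zero-length warning)
# instead of relying on a TypeError from len() being caught.
assert not has_len_all_ranks(iter(range(3)), _StubStrategy(), _StubModule())
```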


def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
def get_len(dataloader: object) -> Union[int, float]:
"""Return the length of the given DataLoader.
If ``__len__`` method is not implemented, return float('inf').
"""

if new_has_len(dataloader):
return len(dataloader) # type: ignore [arg-type]

return float("inf")
length = sized_len(dataloader)
return float("inf") if length is None else length


def _update_dataloader(
