From 8da25f52965395e940150a9b99490809372a0604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 13 Mar 2023 15:58:34 +0100 Subject: [PATCH] Update data docs --- .../{guides => advanced}/speed.rst | 0 docs/source-pytorch/api_references.rst | 1 + docs/source-pytorch/data/access.rst | 43 ++ ...om_data_iterables.rst => alternatives.rst} | 53 ++- docs/source-pytorch/data/data.rst | 54 +++ docs/source-pytorch/data/datamodule.rst | 14 +- docs/source-pytorch/data/infinite.rst | 49 +++ docs/source-pytorch/data/iterables.rst | 93 +++++ docs/source-pytorch/guides/data.rst | 385 ------------------ docs/source-pytorch/index.rst | 8 +- src/lightning/pytorch/core/hooks.py | 126 +----- src/lightning/pytorch/loops/fit_loop.py | 6 +- .../pytorch/loops/prediction_loop.py | 2 + src/lightning/pytorch/trainer/trainer.py | 51 ++- .../pytorch/utilities/combined_loader.py | 43 +- .../utilities/test_combined_loader.py | 11 +- 16 files changed, 369 insertions(+), 570 deletions(-) rename docs/source-pytorch/{guides => advanced}/speed.rst (100%) create mode 100644 docs/source-pytorch/data/access.rst rename docs/source-pytorch/data/{custom_data_iterables.rst => alternatives.rst} (67%) create mode 100644 docs/source-pytorch/data/data.rst create mode 100644 docs/source-pytorch/data/infinite.rst create mode 100644 docs/source-pytorch/data/iterables.rst delete mode 100644 docs/source-pytorch/guides/data.rst diff --git a/docs/source-pytorch/guides/speed.rst b/docs/source-pytorch/advanced/speed.rst similarity index 100% rename from docs/source-pytorch/guides/speed.rst rename to docs/source-pytorch/advanced/speed.rst diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 70a42d823600ae..fd3c2109280818 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -243,6 +243,7 @@ utilities :toctree: api :nosignatures: + combined_loader data deepspeed distributed diff --git a/docs/source-pytorch/data/access.rst b/docs/source-pytorch/data/access.rst new file mode 100644 index 00000000000000..030da8e170ef31 --- /dev/null +++ b/docs/source-pytorch/data/access.rst @@ -0,0 +1,43 @@ +:orphan: + +Accessing DataLoaders +===================== + +In the case that you require access to the DataLoader or Dataset objects, DataLoaders for each step can be accessed +via the trainer properties :meth:`~lightning.pytorch.trainer.trainer.Trainer.train_dataloader`, +:meth:`~lightning.pytorch.trainer.trainer.Trainer.val_dataloaders`, +:meth:`~lightning.pytorch.trainer.trainer.Trainer.test_dataloaders`, and +:meth:`~lightning.pytorch.trainer.trainer.Trainer.predict_dataloaders`. + +.. code-block:: python + + dataloaders = trainer.train_dataloader + dataloaders = trainer.val_dataloaders + dataloaders = trainer.test_dataloaders + dataloaders = trainer.predict_dataloaders + +These properties will match exactly what was returned in your ``*_dataloader`` hooks or passed to the ``Trainer``, +meaning that if you returned a dictionary of dataloaders, these will return a dictionary of dataloaders. + +Replacing DataLoaders +--------------------- + +If you are using a :class:`~lightning.pytorch.utilities.CombinedLoader`. A flattened list of DataLoaders can be accessed by doing: + +.. 
code-block:: python + + from lightning.pytorch.utilities import CombinedLoader + + iterables = {"dl1": dl1, "dl2": dl2} + combined_loader = CombinedLoader(iterables) + # access the original iterables + assert combined_loader.iterables is iterables + # the `.flattened` property can be convenient + assert combined_loader.flattened == [dl1, dl2] + # for example, to do a simple loop + updated = [] + for dl in combined_loader.flattened: + new_dl = apply_some_transformation_to(dl) + updated.append(new_dl) + # it also allows you to easily replace the dataloaders + combined_loader.flattened = updated diff --git a/docs/source-pytorch/data/custom_data_iterables.rst b/docs/source-pytorch/data/alternatives.rst similarity index 67% rename from docs/source-pytorch/data/custom_data_iterables.rst rename to docs/source-pytorch/data/alternatives.rst index 647e7ac28c8b2d..db26f9e80cd970 100644 --- a/docs/source-pytorch/data/custom_data_iterables.rst +++ b/docs/source-pytorch/data/alternatives.rst @@ -1,16 +1,17 @@ +:orphan: + .. _dataiters: -################################## -Injecting 3rd Party Data Iterables -################################## +Using 3rd Party Data Iterables +============================== When training a model on a specific task, data loading and preprocessing might become a bottleneck. Lightning does not enforce a specific data loading approach nor does it try to control it. -The only assumption Lightning makes is that the data is returned as an iterable of batches. +The only assumption Lightning makes is that a valid iterable is provided. For PyTorch-based programs, these iterables are typically instances of :class:`~torch.utils.data.DataLoader`. - -However, Lightning also supports other data types such as plain list of batches, generators or other custom iterables. +However, Lightning also supports other data types such as a list of batches, generators, or other custom iterables or +collections of the former. .. code-block:: python @@ -20,13 +21,23 @@ However, Lightning also supports other data types such as plain list of batches, trainer = Trainer() trainer.fit(model, data) -Examples for custom iterables include `NVIDIA DALI `__ or `FFCV `__ for computer vision. -Both libraries offer support for custom data loading and preprocessing (also hardware accelerated) and can be used with Lightning. +Below we showcase Lightning examples with packages that compete with the generic PyTorch DataLoader and might be +faster depending on your use case. They might require custom data serialization, loading, and preprocessing that +is often hardware accelerated. + +StreamingDataset +^^^^^^^^^^^^^^^^ + +The `StreamingDataset `__ FIXME +FFCV +^^^^ -For example, taking the example from FFCV's readme, we can use it with Lightning by just removing the hardcoded ``ToDevice(0)`` -as Lightning takes care of GPU placement. In case you want to use some data transformations on GPUs, change the -``ToDevice(0)`` to ``ToDevice(self.trainer.local_rank)`` to correctly map to the desired GPU in your pipeline. +Taking the example from the `FFCV `__ readme, we can use it with Lightning +by just removing the hardcoded ``ToDevice(0)`` as Lightning takes care of GPU placement. In case you want to use some +data transformations on GPUs, change the ``ToDevice(0)`` to ``ToDevice(self.trainer.local_rank)`` to correctly map to +the desired GPU in your pipeline. When moving data to a specific device, you can always refer to +``self.trainer.local_rank`` to get the accelerator used by the current process. .. 
code-block:: python @@ -54,8 +65,14 @@ as Lightning takes care of GPU placement. In case you want to use some data tran return loader -When moving data to a specific device, you can always refer to ``self.trainer.local_rank`` to get the accelerator -used by the current process. + +WebDataset +^^^^^^^^^^ + +The `WebDataset `__ FIXME + +NVIDIA DALI +^^^^^^^^^^^ By just changing ``device_id=0`` to ``device_id=self.trainer.local_rank`` we can also leverage DALI's GPU decoding: @@ -107,8 +124,8 @@ Lightning works with all kinds of custom data iterables as shown above. There ar be supported this way. These restrictions come from the fact that for their support, Lightning needs to know a lot on the internals of these iterables. -- In a distributed multi-GPU setting (ddp), - Lightning automatically replaces the DataLoader's sampler with its distributed counterpart. - This makes sure that each GPU sees a different part of the dataset. - As sampling can be implemented in arbitrary ways with custom iterables, - there is no way for Lightning to know, how to replace the sampler. +- In a distributed multi-GPU setting (ddp), Lightning wraps the DataLoader's sampler with a wrapper for distributed + support. This makes sure that each GPU sees a different part of the dataset. As sampling can be implemented in + arbitrary ways with custom iterables, Lightning might not be able to do this for you. If this is the case, you can use + the :paramref:`~lightning.pytorch.trainer.trainer.Trainer.use_distributed_sampler` argument to disable this logic and + set the distributed sampler yourself. diff --git a/docs/source-pytorch/data/data.rst b/docs/source-pytorch/data/data.rst new file mode 100644 index 00000000000000..f9d90f7c7747da --- /dev/null +++ b/docs/source-pytorch/data/data.rst @@ -0,0 +1,54 @@ +.. _data: + +Complex data uses +================= + +.. raw:: html + +
+    <div class="display-card-container">
+        <div class="row">
+
+.. displayitem::
+   :header: LightningDataModules
+   :description: Introduction to the LightningDataModule
+   :col_css: col-md-4
+   :button_link: datamodule.html
+   :height: 150
+   :tag: basic
+
+.. displayitem::
+   :header: Iterables
+   :description: What is an iterable? How do I use them?
+   :col_css: col-md-4
+   :button_link: iterables.html
+   :height: 150
+   :tag: basic
+
+.. displayitem::
+   :header: Access your data
+   :description: How to access your dataloaders
+   :col_css: col-md-4
+   :button_link: access.html
+   :height: 150
+   :tag: basic
+
+.. displayitem::
+   :header: Infinite or streaming datasets
+   :description: Using iterable-style datasets with Lightning
+   :col_css: col-md-4
+   :button_link: infinite.html
+   :height: 150
+   :tag: basic
+
+.. displayitem::
+   :header: Faster DataLoaders
+   :description: How alternative dataloader projects can be used with Lightning
+   :col_css: col-md-4
+   :button_link: alternatives.html
+   :height: 150
+   :tag: basic
+
+.. raw:: html
+
+        </div>
+    </div>
diff --git a/docs/source-pytorch/data/datamodule.rst b/docs/source-pytorch/data/datamodule.rst index 606e83d760afc3..407fd39183a2e4 100644 --- a/docs/source-pytorch/data/datamodule.rst +++ b/docs/source-pytorch/data/datamodule.rst @@ -25,8 +25,6 @@ This class can then be shared and used anywhere: .. code-block:: python - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule - model = LitClassifier() trainer = Trainer() @@ -56,8 +54,11 @@ Datamodules are for you if you ever asked the questions: ********************* What is a DataModule? ********************* -A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and -predict_dataloader(s) along with the matching transforms and data processing/downloads steps required. + +The :class:`~lightning.pytorch.core.datamodule.LightningDataModule` is a convenient way to manage data in PyTorch Lightning. +It encapsulates training, validation, testing, and prediction dataloaders, as well as any necessary steps for data processing, +downloads, and transformations. By using a :class:`~lightning.pytorch.core.datamodule.LightningDataModule`, you can +easily develop dataset-agnostic models, hot-swap different datasets, and share data splits and transformations across projects. Here's a simple PyTorch example: @@ -411,7 +412,10 @@ the method runs on the correct devices). trainer.test(datamodule=dm) You can access the current used datamodule of a trainer via ``trainer.datamodule`` and the current used -dataloaders via ``trainer.train_dataloader``, ``trainer.val_dataloaders`` and ``trainer.test_dataloaders``. +dataloaders via the trainer properties :meth:`~lightning.pytorch.trainer.trainer.Trainer.train_dataloader`, +:meth:`~lightning.pytorch.trainer.trainer.Trainer.val_dataloaders`, +:meth:`~lightning.pytorch.trainer.trainer.Trainer.test_dataloaders`, and +:meth:`~lightning.pytorch.trainer.trainer.Trainer.predict_dataloaders`. ---------------- diff --git a/docs/source-pytorch/data/infinite.rst b/docs/source-pytorch/data/infinite.rst new file mode 100644 index 00000000000000..f12be82c748fa8 --- /dev/null +++ b/docs/source-pytorch/data/infinite.rst @@ -0,0 +1,49 @@ +:orphan: + +Iterable Datasets +================= + +Lightning supports using :class:`~torch.utils.data.IterableDataset` as well as map-style Datasets. IterableDatasets provide a more natural +option when using sequential data. + +.. note:: When using an :class:`~torch.utils.data.IterableDataset` you must set the ``val_check_interval`` to 1.0 (the default) or an int + (specifying the number of training batches to run before each validation loop) when initializing the Trainer. This is + because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation + interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or + an int. If it is set to 0.0 or 0, it will set ``num_{mode}_batches`` to 0, if it is an int, it will set ``num_{mode}_batches`` + to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. + Here ``mode`` can be train/val/test/predict. + +When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect +when the training will stop and run validation if necessary. + +.. 
testcode:: + + # IterableDataset + class CustomDataset(IterableDataset): + def __init__(self, data): + self.data_source = data + + def __iter__(self): + return iter(self.data_source) + + + # Setup DataLoader + def train_dataloader(self): + seq_data = ["A", "long", "time", "ago", "in", "a", "galaxy", "far", "far", "away"] + iterable_dataset = CustomDataset(seq_data) + + dataloader = DataLoader(dataset=iterable_dataset, batch_size=5) + return dataloader + + +.. testcode:: + + # Set val_check_interval as an int + trainer = Trainer(val_check_interval=100) + + # Disable validation: Set limit_val_batches to 0.0 or 0 + trainer = Trainer(limit_val_batches=0.0) + + # Set limit_val_batches as an int + trainer = Trainer(limit_val_batches=100) diff --git a/docs/source-pytorch/data/iterables.rst b/docs/source-pytorch/data/iterables.rst new file mode 100644 index 00000000000000..125ec454e6f364 --- /dev/null +++ b/docs/source-pytorch/data/iterables.rst @@ -0,0 +1,93 @@ +:orphan: + +Arbitrary iterable support +========================== + +Python iterables are objects that can be iterated or looped over. Examples of iterables in Python include lists and dictionaries. +In PyTorch, a :class:`torch.utils.data.DataLoader` is also an iterable which typically retrieves data from a :class:`torch.utils.data.Dataset` or :class:`torch.utils.data.IterableDataset`. + +The :class:`~lightning.pytorch.trainer.trainer.Trainer` works with arbitrary iterables, but most people will use a :class:`torch.utils.data.DataLoader` as the iterable to feed data to the model. + +.. _multiple-dataloaders: + +Multiple Iterables +------------------ + +In addition to supporting arbitrary iterables, the ``Trainer`` also supports arbitrary collections of iterables. Some examples of this are: + +.. code-block:: python + + return DataLoader(...) + return list(range(1000)) + + # pass loaders as a dict. This will create batches like this: + # {'a': batch_from_loader_a, 'b': batch_from_loader_b} + return {"a": DataLoader(...), "b": DataLoader(...)} + + # pass loaders as list. This will create batches like this: + # [batch_from_dl_1, batch_from_dl_2] + return [DataLoader(...), DataLoader(...)] + + # {'a': [batch_from_dl_1, batch_from_dl_2], 'b': [batch_from_dl_3, batch_from_dl_4]} + return {"a": [dl1, dl2], "b": [dl3, dl4]} + +Lightning automatically collates the batches from multiple iterables based on a "mode". This is done with our +:class:`~lightning.pytorch.utilities.combined_loader.CombinedLoader` class. +The list of modes available can be found by looking at the :paramref:`~lightning.pytorch.utilities.combined_loader.CombinedLoader.mode` documentation. + +By default, the ``"max_size_cycle"`` mode is used during training and the ``"sequential"`` mode is used during validation, testing, and prediction. +To choose a different mode, you can use the :class:`~lightning.pytorch.utilities.combined_loader.CombinedLoader` class directly with your mode of choice: + +.. code-block:: python + + from lightning.pytorch.utilities import CombinedLoader + + iterables = {"a": DataLoader(), "b": DataLoader()} + combined_loader = CombinedLoader(iterables, mode="min_size") + model = ... + trainer = Trainer() + trainer.fit(model, combined_loader) + + +Currently, ``trainer.validate``, ``trainer.test``, and ``trainer.predict`` methods only support the ``"sequential"`` mode, while ``trainer.fit`` method does not support it. +Support for this feature is tracked in this `issue `__. 
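+
+For instance, here is a minimal sketch of running validation over two iterables with the ``"sequential"`` mode
+(``val_dl_1``, ``val_dl_2``, and ``model`` are placeholders for your own DataLoaders and LightningModule):
+
+.. code-block:: python
+
+    from lightning.pytorch.utilities import CombinedLoader
+
+    # "sequential" is already the default mode for validation, testing, and prediction,
+    # so building the CombinedLoader yourself is optional here
+    combined_loader = CombinedLoader({"a": val_dl_1, "b": val_dl_2}, mode="sequential")
+    trainer = Trainer()
+    trainer.validate(model, combined_loader)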
+ +Note that when using the ``"sequential"`` mode, you need to add an additional argument ``dataloader_idx`` to some specific hooks. +Lightning will `raise an error `__ informing you of this requirement. + +Using LightningDataModule +------------------------- + +You can set more than one :class:`~torch.utils.data.DataLoader` in your :class:`~lightning.pytorch.core.datamodule.LightningDataModule` using its DataLoader hooks +and Lightning will use the correct one. + +.. testcode:: + + class DataModule(LightningDataModule): + def train_dataloader(self): + # any iterable or collection of iterables + return DataLoader(self.train_dataset) + + def val_dataloader(self): + # any iterable or collection of iterables + return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)] + + def test_dataloader(self): + # any iterable or collection of iterables + return DataLoader(self.test_dataset) + + def predict_dataloader(self): + # any iterable or collection of iterables + return DataLoader(self.predict_dataset) + +Using LightningModule Hooks +--------------------------- + +The exact same code as above works when overriding :class:`~lightning.pytorch.core.module.LightningModule` + +Passing the iterables to the Trainer +------------------------------------ + +The same support for arbitrary iterables, or collection of iterables applies to the dataloader arguments of +:meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`, +:meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict` diff --git a/docs/source-pytorch/guides/data.rst b/docs/source-pytorch/guides/data.rst deleted file mode 100644 index 96b0072368f6a4..00000000000000 --- a/docs/source-pytorch/guides/data.rst +++ /dev/null @@ -1,385 +0,0 @@ -:orphan: - -.. _data: - -############# -Managing Data -############# - -**************************** -Data Containers in Lightning -**************************** - -There are a few different data containers used in Lightning: - -.. list-table:: Data objects - :widths: 20 80 - :header-rows: 1 - - * - Object - - Definition - * - :class:`~torch.utils.data.Dataset` - - The PyTorch :class:`~torch.utils.data.Dataset` represents a map from keys to data samples. - * - :class:`~torch.utils.data.IterableDataset` - - The PyTorch :class:`~torch.utils.data.IterableDataset` represents a stream of data. - * - :class:`~torch.utils.data.DataLoader` - - The PyTorch :class:`~torch.utils.data.DataLoader` represents a Python iterable over a Dataset. - * - :class:`~lightning.pytorch.core.datamodule.LightningDataModule` - - A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` is simply a collection of: training DataLoader(s), validation DataLoader(s), test DataLoader(s) and predict DataLoader(s), along with the matching transforms and data processing/downloads steps required. - - -Why Use LightningDataModule? -============================ - -The :class:`~lightning.pytorch.core.datamodule.LightningDataModule` was designed as a way of decoupling data-related hooks from the :class:`~lightning.pytorch.core.module.LightningModule` so you can develop dataset agnostic models. The :class:`~lightning.pytorch.core.datamodule.LightningDataModule` makes it easy to hot swap different Datasets with your model, so you can test it and benchmark it across domains. It also makes sharing and reusing the exact data splits and transforms across projects possible. 
- -Read :ref:`this ` for more details on LightningDataModule. - ---------- - -.. _multiple-dataloaders: - -***************** -Multiple Datasets -***************** - -There are a few ways to pass multiple Datasets to Lightning: - -1. Create a DataLoader that iterates over multiple Datasets under the hood. -2. In the training loop, you can pass multiple DataLoaders as a dict or list/tuple, and Lightning will - automatically combine the batches from different DataLoaders. -3. In the validation, test, or prediction, you have the option to return multiple DataLoaders as list/tuple, which Lightning will call sequentially - or combine the DataLoaders using :class:`~lightning.pytorch.utilities.CombinedLoader`, which is what Lightning uses - under the hood. - - -Using LightningDataModule -========================= - -You can set more than one :class:`~torch.utils.data.DataLoader` in your :class:`~lightning.pytorch.core.datamodule.LightningDataModule` using its DataLoader hooks -and Lightning will use the correct one. - -.. testcode:: - - class DataModule(LightningDataModule): - ... - - def train_dataloader(self): - return DataLoader(self.train_dataset) - - def val_dataloader(self): - return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)] - - def test_dataloader(self): - return DataLoader(self.test_dataset) - - def predict_dataloader(self): - return DataLoader(self.predict_dataset) - - -Using LightningModule Hooks -=========================== - -Concatenated Dataset --------------------- - -For training with multiple Datasets, you can create a :class:`~torch.utils.data.DataLoader` class -which wraps your multiple Datasets using :class:`~torch.utils.data.ConcatDataset`. This, of course, -also works for testing, validation, and prediction Datasets. - -.. testcode:: - - from torch.utils.data import ConcatDataset - - - class LitModel(LightningModule): - def train_dataloader(self): - concat_dataset = ConcatDataset(datasets.ImageFolder(traindir_A), datasets.ImageFolder(traindir_B)) - - loader = DataLoader( - concat_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True - ) - return loader - - def val_dataloader(self): - # SAME - ... - - def test_dataloader(self): - # SAME - ... - - -Return Multiple DataLoaders ---------------------------- - -You can set multiple DataLoaders in your :class:`~lightning.pytorch.core.module.LightningModule`, and Lightning will take care of batch combination. - -.. testcode:: - - class LitModel(LightningModule): - def train_dataloader(self): - loader_a = DataLoader(range(6), batch_size=4) - loader_b = DataLoader(range(15), batch_size=5) - - # pass loaders as a dict. This will create batches like this: - # {'a': batch from loader_a, 'b': batch from loader_b} - loaders = {"a": loader_a, "b": loader_b} - - # OR: - # pass loaders as sequence. This will create batches like this: - # [batch from loader_a, batch from loader_b] - loaders = [loader_a, loader_b] - - return loaders - -Furthermore, Lightning also supports nested lists and dicts (or a combination). - -.. testcode:: - - class LitModel(LightningModule): - def train_dataloader(self): - loader_a = DataLoader(range(8), batch_size=4) - loader_b = DataLoader(range(16), batch_size=2) - - return {"a": loader_a, "b": loader_b} - - def training_step(self, batch, batch_idx): - # access a dictionary with a batch from each DataLoader - batch_a = batch["a"] - batch_b = batch["b"] - - -.. 
testcode:: - - class LitModel(LightningModule): - def train_dataloader(self): - loader_a = DataLoader(range(8), batch_size=4) - loader_b = DataLoader(range(16), batch_size=4) - loader_c = DataLoader(range(32), batch_size=4) - loader_c = DataLoader(range(64), batch_size=4) - - # pass loaders as a nested dict. This will create batches like this: - loaders = {"loaders_a_b": [loader_a, loader_b], "loaders_c_d": {"c": loader_c, "d": loader_d}} - return loaders - - def training_step(self, batch, batch_idx): - # access the data - batch_a_b = batch["loaders_a_b"] - batch_c_d = batch["loaders_c_d"] - - batch_a = batch_a_b[0] - batch_b = batch_a_b[1] - - batch_c = batch_c_d["c"] - batch_d = batch_c_d["d"] - -Alternatively, you can also pass in a :class:`~lightning.pytorch.utilities.CombinedLoader` containing multiple DataLoaders. - -.. testcode:: - - from lightning.pytorch.utilities import CombinedLoader - - - def train_dataloader(self): - loader_a = DataLoader() - loader_b = DataLoader() - loaders = {"a": loader_a, "b": loader_b} - combined_loader = CombinedLoader(loaders, mode="max_size_cycle") - return combined_loader - - - def training_step(self, batch, batch_idx): - batch_a = batch["a"] - batch_b = batch["b"] - - -Multiple Validation/Test/Predict DataLoaders -============================================ - -For validation, test and predict DataLoaders, you can pass a single DataLoader or a list of them. This optional named -parameter can be used in conjunction with any of the above use cases. You can choose to pass -the batches sequentially or simultaneously, as is done for the training step. -The default mode for these DataLoaders is sequential. Note that when using a sequence of DataLoaders you need -to add an additional argument ``dataloader_idx`` in their corresponding step specific hook. The corresponding loop will process -the DataLoaders in sequential order; that is, the first DataLoader will be processed completely, then the second one, and so on. - -Refer to the following for more details for the default sequential option: - -- :meth:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` -- :meth:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` -- :meth:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` - -.. testcode:: - - def val_dataloader(self): - loader_1 = DataLoader() - loader_2 = DataLoader() - return [loader_1, loader_2] - - - def validation_step(self, batch, batch_idx, dataloader_idx): - ... - - -Evaluation DataLoaders are iterated over sequentially. The above is equivalent to: - -.. testcode:: - - from lightning.pytorch.utilities import CombinedLoader - - - def val_dataloader(self): - loader_a = DataLoader() - loader_b = DataLoader() - loaders = {"a": loader_a, "b": loader_b} - combined_loaders = CombinedLoader(loaders, mode="sequential") - return combined_loaders - - - def validation_step(self, batch, batch_idx): - batch_a = batch["a"] - batch_b = batch["b"] - - -Evaluate with Additional DataLoaders -==================================== - -You can evaluate your models using additional DataLoaders even if the DataLoader specific hooks haven't been defined within your -:class:`~lightning.pytorch.core.module.LightningModule`. For example, this would be the case if your test data -set is not available at the time your model was declared. Simply pass the test set to the :meth:`~lightning.pytorch.trainer.trainer.Trainer.test` method: - -.. code-block:: python - - # setup your DataLoader - test = DataLoader(...) 
- - # test (pass in the loader) - trainer.test(dataloaders=test) - --------------- - -******************************************** -Accessing DataLoaders within LightningModule -******************************************** - -In the case that you require access to the DataLoader or Dataset objects, DataLoaders for each step can be accessed using the ``Trainer`` object: - -.. testcode:: - - from lightning.pytorch import LightningModule - - - class Model(LightningModule): - def test_step(self, batch, batch_idx, dataloader_idx): - test_dl = self.trainer.test_dataloaders[dataloader_idx] - test_dataset = test_dl.dataset - test_sampler = test_dl.sampler - ... - # extract metadata, etc. from the dataset: - ... - -If you are using a :class:`~lightning.pytorch.utilities.CombinedLoader` object which allows you to fetch batches from a collection of DataLoaders -simultaneously which supports collections of DataLoader such as list, tuple, or dictionary. The DataLoaders can be accessed using the same collection structure: - -.. code-block:: python - - from lightning.pytorch.utilities import CombinedLoader - - test_dl1 = ... - test_dl2 = ... - - # If you provided a list of DataLoaders: - - combined_loader = CombinedLoader([test_dl1, test_dl2]) - list_of_loaders = combined_loader.iterables - test_dl1 = list_of_loaders.loaders[0] - - - # If you provided dictionary of DataLoaders: - - combined_loader = CombinedLoader({"dl1": test_dl1, "dl2": test_dl2}) - dictionary_of_loaders = combined_loader.iterables - test_dl1 = dictionary_of_loaders["dl1"] - --------------- - -.. _sequential-data: - -*************** -Sequential Data -*************** - -Lightning has built in support for dealing with sequential data. - - -Packed Sequences as Inputs -========================== - -When using :class:`~torch.nn.utils.rnn.PackedSequence`, do two things: - -1. Return either a padded tensor in dataset or a list of variable length tensors in the DataLoader's `collate_fn `_ (example shows the list implementation). -2. Pack the sequence in forward or training and validation steps depending on use case. - -| - -.. testcode:: - - # For use in DataLoader - def collate_fn(batch): - x = [item[0] for item in batch] - y = [item[1] for item in batch] - return x, y - - - # In LightningModule - def training_step(self, batch, batch_idx): - x = rnn.pack_sequence(batch[0], enforce_sorted=False) - y = rnn.pack_sequence(batch[1], enforce_sorted=False) - -Iterable Datasets -================= -Lightning supports using :class:`~torch.utils.data.IterableDataset` as well as map-style Datasets. IterableDatasets provide a more natural -option when using sequential data. - -.. note:: When using an :class:`~torch.utils.data.IterableDataset` you must set the ``val_check_interval`` to 1.0 (the default) or an int - (specifying the number of training batches to run before each validation loop) when initializing the Trainer. This is - because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation - interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or - an int. If it is set to 0.0 or 0, it will set ``num_{mode}_batches`` to 0, if it is an int, it will set ``num_{mode}_batches`` - to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. - Here ``mode`` can be train/val/test/predict. 
- -When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect -when the training will stop and run validation if necessary. - -.. testcode:: - - # IterableDataset - class CustomDataset(IterableDataset): - def __init__(self, data): - self.data_source = data - - def __iter__(self): - return iter(self.data_source) - - - # Setup DataLoader - def train_dataloader(self): - seq_data = ["A", "long", "time", "ago", "in", "a", "galaxy", "far", "far", "away"] - iterable_dataset = CustomDataset(seq_data) - - dataloader = DataLoader(dataset=iterable_dataset, batch_size=5) - return dataloader - - -.. testcode:: - - # Set val_check_interval - trainer = Trainer(val_check_interval=100) - - # Set limit_val_batches to 0.0 or 0 - trainer = Trainer(limit_val_batches=0.0) - - # Set limit_val_batches as an int - trainer = Trainer(limit_val_batches=100) diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst index 1ebda03407ac02..41c9d770557539 100644 --- a/docs/source-pytorch/index.rst +++ b/docs/source-pytorch/index.rst @@ -206,7 +206,7 @@ Current Lightning Users Train on single or multiple TPUs Train on MPS Use a pretrained model - Inject Custom Data Iterables + data/data model/own_your_loop .. toctree:: @@ -233,7 +233,6 @@ Current Lightning Users Lightning CLI LightningDataModule LightningModule - Lightning Transformers Log TPU Metrics @@ -283,8 +282,3 @@ Current Lightning Users .. raw:: html - -.. PyTorch-Lightning documentation master file, created by - sphinx-quickstart on Fri Nov 15 07:48:22 2019. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index c01653bdb74d36..4e8be7d0558ff6 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -388,11 +388,9 @@ def teardown(self, stage: str) -> None: """ def train_dataloader(self) -> TRAIN_DATALOADERS: - """Implement one or more PyTorch DataLoaders for training. + """An iterable or collection of iterables specifying training samples. - Return: - A collection of :class:`torch.utils.data.DataLoader` specifying training samples. - In the case of multiple dataloaders, please see this :ref:`section `. + For more information about multiple dataloaders, see this :ref:`section `. The dataloader you return will not be reloaded unless you set :paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to @@ -412,55 +410,15 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: - :meth:`setup` Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. + Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself. - - Example:: - - # single dataloader - def train_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, - download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.batch_size, - shuffle=True - ) - return loader - - # multiple dataloaders, return as list - def train_dataloader(self): - mnist = MNIST(...) - cifar = CIFAR(...) 
- mnist_loader = torch.utils.data.DataLoader( - dataset=mnist, batch_size=self.batch_size, shuffle=True - ) - cifar_loader = torch.utils.data.DataLoader( - dataset=cifar, batch_size=self.batch_size, shuffle=True - ) - # each batch will be a list of tensors: [batch_mnist, batch_cifar] - return [mnist_loader, cifar_loader] - - # multiple dataloader, return as dict - def train_dataloader(self): - mnist = MNIST(...) - cifar = CIFAR(...) - mnist_loader = torch.utils.data.DataLoader( - dataset=mnist, batch_size=self.batch_size, shuffle=True - ) - cifar_loader = torch.utils.data.DataLoader( - dataset=cifar, batch_size=self.batch_size, shuffle=True - ) - # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} - return {'mnist': mnist_loader, 'cifar': cifar_loader} """ raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer") def test_dataloader(self) -> EVAL_DATALOADERS: - r""" - Implement one or multiple PyTorch DataLoaders for testing. + r"""An iterable or collection of iterables specifying test samples. + + For more information about multiple dataloaders, see this :ref:`section `. For data processing use the following pattern: @@ -477,44 +435,19 @@ def test_dataloader(self) -> EVAL_DATALOADERS: - :meth:`setup` Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. + Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself. - Return: - A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples. - - Example:: - - def test_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, - download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.batch_size, - shuffle=False - ) - - return loader - - # can also return multiple dataloaders - def test_dataloader(self): - return [loader_a, loader_b, ..., loader_n] - Note: If you don't need a test dataset and a :meth:`test_step`, you don't need to implement this method. - - Note: - In the case where you return multiple test dataloaders, the :meth:`test_step` - will have an argument ``dataloader_idx`` which matches the order here. """ raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer") def val_dataloader(self) -> EVAL_DATALOADERS: - r""" - Implement one or multiple PyTorch DataLoaders for validation. + r"""An iterable or collection of iterables specifying validation samples. + + For more information about multiple dataloaders, see this :ref:`section `. The dataloader you return will not be reloaded unless you set :paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to @@ -528,44 +461,19 @@ def val_dataloader(self) -> EVAL_DATALOADERS: - :meth:`setup` Note: - Lightning adds the correct sampler for distributed and arbitrary hardware + Lightning tries to add the correct sampler for distributed and arbitrary hardware There is no need to set it yourself. - Return: - A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. 
- - Examples:: - - def val_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=False, - transform=transform, download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.batch_size, - shuffle=False - ) - - return loader - - # can also return multiple dataloaders - def val_dataloader(self): - return [loader_a, loader_b, ..., loader_n] - Note: If you don't need a validation dataset and a :meth:`validation_step`, you don't need to implement this method. - - Note: - In the case where you return multiple validation dataloaders, the :meth:`validation_step` - will have an argument ``dataloader_idx`` which matches the order here. """ raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer") def predict_dataloader(self) -> EVAL_DATALOADERS: - r""" - Implement one or multiple PyTorch DataLoaders for prediction. + r"""An iterable or collection of iterables specifying prediction samples. + + For more information about multiple dataloaders, see this :ref:`section `. It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. @@ -574,15 +482,11 @@ def predict_dataloader(self) -> EVAL_DATALOADERS: - :meth:`setup` Note: - Lightning adds the correct sampler for distributed and arbitrary hardware + Lightning tries to add the correct sampler for distributed and arbitrary hardware There is no need to set it yourself. Return: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. - - Note: - In the case where you return multiple prediction dataloaders, the :meth:`predict_step` - will have an argument ``dataloader_idx`` which matches the order here. """ raise MisconfigurationException( "`predict_dataloader` must be implemented to be used with the Lightning Trainer" diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 3409d33c229f09..663aad8f64a491 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -343,10 +343,8 @@ def advance(self) -> None: combined_loader = self._combined_loader assert combined_loader is not None - if combined_loader._mode not in ("max_size_cycle", "min_size"): - raise ValueError( - f'`{type(self).__name__}` only supports the `CombinedLoader(mode="max_size_cycle" | "min_size")` modes.' 
- ) + if combined_loader._mode == "sequential": + raise ValueError(f'`{type(self).__name__}` does not support the `CombinedLoader(mode="sequential")` mode.') assert self._data_fetcher is not None self._data_fetcher.setup(combined_loader) with self.trainer.profiler.profile("run_training_epoch"): diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index ca543cb2e1576a..b6ee280f882964 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -163,6 +163,8 @@ def reset(self) -> None: ) combined_loader = self._combined_loader assert combined_loader is not None + if combined_loader._mode != "sequential": + raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.') data_fetcher.setup(combined_loader) iter(data_fetcher) # creates the iterator inside the fetcher assert isinstance(combined_loader._iterator, _Sequential) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 0fbd619af5c683..653152f239759c 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -501,17 +501,20 @@ def fit( Args: model: Model to fit. - train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying training samples. - In the case of multiple dataloaders, please see this :ref:`section `. + train_dataloaders: An iterable or collection of iterables specifying training samples. + Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook. - val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + val_dataloaders: An iterable or collection of iterables specifying validation samples. ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. - datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`. + datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook. + + For more information about multiple dataloaders, see this :ref:`section `. """ model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model @@ -574,8 +577,9 @@ def validate( Args: model: The model to validate. - dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, - or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying validation samples. + dataloaders: An iterable or collection of iterables specifying validation samples. + Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate. If ``None`` and the model instance was passed, use the current weights. @@ -584,7 +588,10 @@ def validate( verbose: If True, prints the validation results. 
- datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`. + datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. + + For more information about multiple dataloaders, see this :ref:`section `. Returns: List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks @@ -667,8 +674,9 @@ def test( Args: model: The model to test. - dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, - or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying test samples. + dataloaders: An iterable or collection of iterables specifying test samples. + Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test. If ``None`` and the model instance was passed, use the current weights. @@ -677,7 +685,10 @@ def test( verbose: If True, prints the test results. - datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`. + datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. + + For more information about multiple dataloaders, see this :ref:`section `. Returns: List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks @@ -761,10 +772,12 @@ def predict( Args: model: The model to predict with. - dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, - or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying prediction samples. + dataloaders: An iterable or collection of iterables specifying predict samples. + Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook. - datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. + datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines + the `:class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook. return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). @@ -774,6 +787,8 @@ def predict( Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. + For more information about multiple dataloaders, see this :ref:`section `. + Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. 
@@ -1361,13 +1376,13 @@ def is_last_batch(self) -> bool: return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property - def train_dataloader(self) -> TRAIN_DATALOADERS: + def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]: """The training dataloader(s) used during ``trainer.fit()``.""" if (combined_loader := self.fit_loop._combined_loader) is not None: return combined_loader.iterables @property - def val_dataloaders(self) -> EVAL_DATALOADERS: + def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``.""" if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None: return combined_loader.iterables @@ -1375,13 +1390,13 @@ def val_dataloaders(self) -> EVAL_DATALOADERS: return combined_loader.iterables @property - def test_dataloaders(self) -> EVAL_DATALOADERS: + def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The test dataloader(s) used during ``trainer.test()``.""" if (combined_loader := self.test_loop._combined_loader) is not None: return combined_loader.iterables @property - def predict_dataloaders(self) -> EVAL_DATALOADERS: + def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The prediction dataloader(s) used during ``trainer.predict()``.""" if (combined_loader := self.predict_loop._combined_loader) is not None: return combined_loader.iterables diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 895de912bc4ec1..db89f86a61d6a0 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -154,30 +154,30 @@ class _CombinationMode(TypedDict): iterator: Type[_ModeIterator] -_supported_modes = { +_SUPPORTED_MODES = { "min_size": _CombinationMode(fn=min, iterator=_MinSize), "max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle), "max_size": _CombinationMode(fn=max, iterator=_MaxSize), "sequential": _CombinationMode(fn=sum, iterator=_Sequential), } - _LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] class CombinedLoader(Iterable): - """Combines different iterables under custom sampling modes. + """Combines different iterables under specific sampling modes. + + The following modes are supported: + * ``min_size``: stops after the shortest iterable (the one with the lowest number of items) is done. + * ``max_size_cycle``: stops after the longest iterable (the one with most items) is done, while cycling through + the rest of the iterables. + * ``max_size``: stops after the longest iterable (the one with most items) is done, while returning None for the + exhausted iterables. + * ``sequential``: completely consumes ecah iterable sequentially, and returns a triplet + ``(data, idx, iterable_idx)`` Args: - iterables: the loaders to sample from. Can be any kind of collection - mode: - * ``"min_size"``, which raises StopIteration after the shortest iterable (the one with the lowest number of - items) is done. - * ``"max_size_cycle"`` which raises StopIteration after the longest iterable (the one with most items) is - done, while cycling through rest of the iterables. - * ``"max_size"`` which raises StopIteration after the longest iterable (the one with most items) is - done, while returning None for exhausted iterables. - * ``"sequential"`` will consume ecah iterable sequentially, and returns a tuple with the associated index - from each iterable. 
+ iterables: the iterable or collection of iterables to sample from. + mode: the mode to use. Examples: >>> from torch.utils.data import DataLoader @@ -191,6 +191,7 @@ class CombinedLoader(Iterable): {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} + >>> combined_loader = CombinedLoader(iterables, 'max_size') >>> len(combined_loader) 3 @@ -199,6 +200,7 @@ class CombinedLoader(Iterable): {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': None, 'b': tensor([10, 11, 12, 13, 14])} + >>> combined_loader = CombinedLoader(iterables, 'min_size') >>> len(combined_loader) 2 @@ -206,6 +208,7 @@ class CombinedLoader(Iterable): ... print(batch) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + >>> combined_loader = CombinedLoader(iterables, 'sequential') >>> len(combined_loader) 5 @@ -219,8 +222,8 @@ class CombinedLoader(Iterable): """ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None: - if mode not in _supported_modes: - raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.") + if mode not in _SUPPORTED_MODES: + raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_SUPPORTED_MODES)}.") self._iterables = iterables self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode @@ -248,6 +251,7 @@ def flattened(self) -> List[Any]: @flattened.setter def flattened(self, flattened: List[Any]) -> None: + """Setter to conveniently update the list of iterables.""" if len(flattened) != len(self._flattened): raise ValueError( f"Mismatch in flattened length ({len(flattened)}) and existing length ({len(self._flattened)})" @@ -264,7 +268,7 @@ def __next__(self) -> Any: return tree_unflatten(out, self._spec) def __iter__(self) -> Self: - cls = _supported_modes[self._mode]["iterator"] + cls = _SUPPORTED_MODES[self._mode]["iterator"] iterator = cls(self.flattened) iter(iterator) self._iterator = iterator @@ -278,10 +282,11 @@ def __len__(self) -> int: if length is None: raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`") lengths.append(length) - fn = _supported_modes[self._mode]["fn"] + fn = _SUPPORTED_MODES[self._mode]["fn"] return fn(lengths) def reset(self) -> None: + """Reset the state and shutdown any workers.""" if self._iterator is not None: self._iterator.reset() self._iterator = None @@ -289,12 +294,12 @@ def reset(self) -> None: _shutdown_workers_and_reset_iterator(iterable) def _dataset_length(self) -> int: - """Compute the total length of the datasets according to the `mode`.""" + """Compute the total length of the datasets according to the current mode.""" datasets = [getattr(dl, "dataset", None) for dl in self.flattened] lengths = [length for ds in datasets if (length := sized_len(ds)) is not None] if not lengths: raise NotImplementedError("All datasets are iterable-style datasets.") - fn = _supported_modes[self._mode]["fn"] + fn = _SUPPORTED_MODES[self._mode]["fn"] return fn(lengths) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 68e06237720b14..421e171435597d 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -12,7 +12,7 @@ # See the License for the specific 
language governing permissions and # limitations under the License. import math -from typing import Any, NamedTuple, Sequence +from typing import Any, get_args, NamedTuple, Sequence import pytest import torch @@ -27,11 +27,12 @@ from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.combined_loader import ( + _LITERAL_SUPPORTED_MODES, _MaxSize, _MaxSizeCycle, _MinSize, _Sequential, - _supported_modes, + _SUPPORTED_MODES, CombinedLoader, ) from tests_pytorch.helpers.runif import RunIf @@ -517,7 +518,7 @@ def test_combined_dataloader_for_training_with_ddp(use_distributed_sampler, mode ) trainer.strategy.connect(model) trainer._data_connector.attach_data(model=model, train_dataloaders=dataloader) - fn = _supported_modes[mode]["fn"] + fn = _SUPPORTED_MODES[mode]["fn"] expected_length_before_ddp = fn([n1, n2]) expected_length_after_ddp = ( math.ceil(expected_length_before_ddp / trainer.num_devices) @@ -531,3 +532,7 @@ def test_combined_dataloader_for_training_with_ddp(use_distributed_sampler, mode assert isinstance(trainer.fit_loop._combined_loader, CombinedLoader) assert trainer.fit_loop._combined_loader._mode == mode assert trainer.num_training_batches == expected_length_after_ddp + + +def test_supported_modes(): + assert set(_SUPPORTED_MODES) == set(get_args(_LITERAL_SUPPORTED_MODES))
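+
+
+def test_sequential_mode_yields_triplets():
+    # Illustrative sketch (hypothetical test, not part of the original change set): the updated
+    # docstring states that the "sequential" mode consumes each iterable completely before moving
+    # on to the next one and yields (data, idx, iterable_idx) triplets.
+    combined_loader = CombinedLoader({"a": [1, 2], "b": [3, 4, 5]}, "sequential")
+    iterable_indices = [iterable_idx for _, _, iterable_idx in combined_loader]
+    assert iterable_indices == [0, 0, 1, 1, 1]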