Sequential CombinedLoader to flatten the eval and predict loops (#1…
carmocca authored Feb 17, 2023
1 parent ccd2a48 commit ec4f592
Showing 42 changed files with 1,209 additions and 1,393 deletions.
18 changes: 18 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -42,6 +42,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))


- Added support for `predict_step(dataloader_iter, batch_index)` ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Added support for arbitrary iterables as dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784))

### Changed
@@ -87,6 +93,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))


- The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- The `trainer.*_dataloader` properties now return what the user returned in their `LightningModule.*_dataloader()` hook ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753))
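A hedged sketch of the hook-signature change in the entry above; the default of 0 follows the changelog wording, and the hook body is illustrative:

```python
import lightning.pytorch as pl

class MyModel(pl.LightningModule):
    # dataloader_idx is now optional: drop it, or default it to 0 when you
    # do not use multiple validation dataloaders.
    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        ...
```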


@@ -210,6 +222,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))


- Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Removed `trainer.reset_*_dataloader()` methods in favor of `Loop.setup_data()` for the top-level loops ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute
* Removed the `LightningModule.tbptt_split_batch` hook
20 changes: 10 additions & 10 deletions src/lightning/pytorch/callbacks/batch_size_finder.py
@@ -128,24 +128,25 @@ def __init__(
 def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
     if trainer._accelerator_connector.is_distributed:
         raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.")

-    running_stage = trainer.state.stage
-    assert running_stage is not None
-    dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source")
-
     # TODO: check if this can be enabled (#4040)
-    if not trainer._data_connector._train_dataloader_source.is_module():
+    if not trainer.fit_loop._data_source.is_module():
         raise MisconfigurationException(
             "The Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable"
             " the feature or incorporate the dataloader into your LightningModule or LightningDataModule."
         )

     # TODO: Add support for multiple eval dataloader
-    if stage != "fit":
-        dataloaders = dl_source.dataloader()
-        if isinstance(dataloaders, list) and len(dataloaders) > 1:
+    loop = trainer._active_loop
+    assert loop is not None
+    loop.setup_data()
+    combined_loader = loop._combined_loader
+    assert combined_loader is not None
+    if len(combined_loader._flattened) > 1:
+        stage = trainer.state.stage
+        assert stage is not None
         raise MisconfigurationException(
-            f"The Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders."
+            f"The Batch size finder cannot be used with multiple {stage.dataloader_prefix} dataloaders."
         )

     if not lightning_hasattr(pl_module, self._batch_arg_name):
@@ -167,7 +168,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
 def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
     new_size = _scale_batch_size(
         trainer,
-        pl_module,
         self._mode,
         self._steps_per_trial,
         self._init_val,
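The hunks above rewire `BatchSizeFinder` from the data-connector sources to `loop.setup_data()` and the loop's `CombinedLoader`. For context, a hedged usage sketch of the callback itself; the argument values are illustrative, and the names mirror the `self._mode`, `self._steps_per_trial`, and `self._init_val` attributes referenced in the diff:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import BatchSizeFinder

# The module (or datamodule) must expose a `batch_size` attribute for the
# finder to tune; it cannot be used with dataloaders passed directly to .fit().
trainer = pl.Trainer(callbacks=[BatchSizeFinder(mode="power", steps_per_trial=3, init_val=2)])
# trainer.fit(model, datamodule=datamodule)
```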
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/prediction_writer.py
@@ -140,7 +140,7 @@ def on_predict_batch_end(
 ) -> None:
     if not self.interval.on_batch:
         return
-    batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices
+    batch_indices = trainer.predict_loop.current_batch_indices
     self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

 def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
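The one-line change above reflects the flattened prediction loop: `current_batch_indices` now lives on `trainer.predict_loop` directly. A hedged sketch of a `BasePredictionWriter` callback that consumes those indices; the output path and file naming are illustrative:

```python
import os
import torch
from lightning.pytorch.callbacks import BasePredictionWriter

class PredictionWriter(BasePredictionWriter):
    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # batch_indices is populated from trainer.predict_loop.current_batch_indices
        torch.save(prediction, os.path.join(self.output_dir, f"dl{dataloader_idx}_batch{batch_idx}.pt"))
```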
5 changes: 3 additions & 2 deletions src/lightning/pytorch/loops/__init__.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from lightning.pytorch.loops.loop import _Loop  # noqa: F401 isort: skip (avoids circular imports)
-from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop  # noqa: F401
-from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop  # noqa: F401
+from lightning.pytorch.loops.epoch import _TrainingEpochLoop  # noqa: F401
+from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop  # noqa: F401
 from lightning.pytorch.loops.fit_loop import _FitLoop  # noqa: F401
 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization  # noqa: F401
+from lightning.pytorch.loops.prediction_loop import _PredictionLoop  # noqa: F401
17 changes: 0 additions & 17 deletions src/lightning/pytorch/loops/dataloader/__init__.py

This file was deleted.

68 changes: 0 additions & 68 deletions src/lightning/pytorch/loops/dataloader/dataloader_loop.py

This file was deleted.

178 changes: 0 additions & 178 deletions src/lightning/pytorch/loops/dataloader/prediction_loop.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/epoch/__init__.py
@@ -12,6 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from lightning.pytorch.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop  # noqa: F401
-from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop  # noqa: F401
 from lightning.pytorch.loops.epoch.training_epoch_loop import _TrainingEpochLoop  # noqa: F401
