
trainer.fit_loop.setup_data() does not refresh train dataset in LightningModule #17327

Open
LarsKue opened this issue Apr 11, 2023 · 6 comments
Labels
data handling Generic data-related topic feature Is an improvement or enhancement loops Related to the Loop API pl Generic label for PyTorch Lightning package ver: 2.0.x


LarsKue commented Apr 11, 2023

Bug description

PR #16726 replaces the reset_*_dataloader() method calls with the respective Loop.setup_data() calls. This is also mentioned in the migration guide.

However, on versions <= 1.9, calling reset_train_dataloader() would reinstantiate the dataloader from a LightningModule's train_dataloader() method. This behaviour is now gone.

My specific use case is that I need to update the dataset of my model during training. I use on_train_epoch_end() or a similar hook to call reset_train_dataloader() so that the next training epoch uses the updated dataset. I posted a minimal example below. You can run it on both v1.9 and v2.0 to see the exact difference: v1.9 runs without problems, whereas v2.0 fails the second assertion in training_step(). I tested it on fresh conda env installs of both versions using Python 3.10.

In case I am using the wrong loop to call setup_data() or am using the new interface incorrectly, please let me know. In that case I would also recommend providing more hints in the migration guide or on PR #16726, since the current advice is not entirely clear (e.g., which loops are "top level"?).

What version are you seeing the problem on?

2.0+

How to reproduce the bug

try:
    import lightning
except ModuleNotFoundError:
    import pytorch_lightning as lightning

import torch
from torch.utils.data import DataLoader, TensorDataset


class Model(lightning.LightningModule):

    def __init__(self):
        super().__init__()
        self.train_data = TensorDataset(torch.zeros(1, 1))

    def configure_optimizers(self):
        return None

    def on_train_epoch_end(self):
        self.train_data = TensorDataset(torch.ones(1, 1))

        if int(lightning.__version__[0]) < 2:
            # for version < 2.0 (works)
            self.trainer.reset_train_dataloader()
        else:
            # for version >= 2.0 (does not work)
            self.trainer.fit_loop.setup_data()

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
        )

    def training_step(self, batch, batch_idx):
        # de-tuple
        batch = batch[0]

        if self.trainer.global_step == 0:
            assert torch.allclose(batch, torch.zeros_like(batch))
        else:
            # this assertion fails on lightning v2.0
            assert torch.allclose(batch, torch.ones_like(batch))

        return torch.tensor(0.0, requires_grad=True)


model = Model()
trainer = lightning.Trainer(max_steps=2)
trainer.fit(model)

Error messages and logs

File "/home/lars/code/python/lightning-trainable/playground.py", line 38, in training_step
    assert torch.allclose(batch, torch.ones_like(batch))
AssertionError

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, FitLoop, LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 1.9 / 2.0
#- Lightning App Version (e.g., 0.5.2): -
#- PyTorch Version (e.g., 2.0): 1.9 / 2.0
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Ubuntu
#- CUDA/cuDNN version: 11.7
#- GPU models and configuration: RTX 2070
#- How you installed Lightning(`conda`, `pip`, source): conda
#- Running environment of LightningApp (e.g. local, cloud): local

More info

No response

cc @Borda @justusschock @awaelchli @carmocca

@LarsKue LarsKue added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Apr 11, 2023
@carmocca
Contributor

You can do Trainer(reload_dataloaders_every_n_epochs=1) to accomplish this
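To make the effect of `reload_dataloaders_every_n_epochs` on the schedule concrete, here is a dependency-free sketch. It assumes a rebuild is triggered once at least `n` epochs have completed since the last one, and that `n == 0` disables reloading. The function name `epochs_with_reload` is hypothetical and this is a model of the setting's meaning, not Lightning's implementation.

```python
def epochs_with_reload(max_epochs: int, n: int) -> list[int]:
    """Return the epoch indices at whose start the train dataloader is built.

    Hypothetical model: epoch 0 always builds the loader; afterwards a rebuild
    happens once at least `n` epochs have completed since the last build.
    `n == 0` disables reloading entirely.
    """
    reloads = [0]  # the first epoch always needs a dataloader
    last_reload = 0
    for epoch in range(1, max_epochs):
        if n > 0 and epoch - last_reload >= n:
            reloads.append(epoch)
            last_reload = epoch
    return reloads


print(epochs_with_reload(5, 1))  # [0, 1, 2, 3, 4]
print(epochs_with_reload(5, 2))  # [0, 2, 4]
print(epochs_with_reload(5, 0))  # [0]
```

With `n=1`, as suggested here, the loader is rebuilt at the start of every epoch, which is a fixed schedule rather than a reload on command.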

@carmocca carmocca added data handling Generic data-related topic question Further information is requested and removed bug Something isn't working needs triage Waiting to be triaged by maintainers labels Apr 11, 2023
@LarsKue
Author

LarsKue commented Apr 11, 2023

You can do Trainer(reload_dataloaders_every_n_epochs=1) to accomplish this

This solution is unsatisfactory since I want to a) avoid reloading every epoch and b) be able to reload irregularly and on command.

I also think you should re-add the bug tag, since the functionality of setup_data still seems broken to me, even if there technically is a different way to do this.

@carmocca
Contributor

This is "working as expected" given the current design of setup_data, which doesn't run if the data is already setup and the trainer flag is not configured, see this early exit: https://github.com/Lightning-AI/lightning/blob/b2717f68789638f34bd9baca2d74b62a06c16ca9/src/lightning/pytorch/loops/fit_loop.py#L210-L211

If you prevent that if statement from triggering, you'll see your code pass. For example, set trainer.fit_loop._combined_loader = None before you call setup_data().

The easiest way to change this behaviour would be to add a force: bool flag to setup_data() so that you can skip that early exit. That would make this a feature rather than a bug.
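The proposed `force` flag does not exist in Lightning 2.0. The toy loop below is only a sketch of the suggested semantics, assuming the early exit checks for an already-built loader plus a reload flag; `ToyFitLoop`, `should_reload`, and `force` are hypothetical names.

```python
from typing import Callable, Optional


class ToyFitLoop:
    """Toy model of the proposed `setup_data(force=...)` semantics (hypothetical)."""

    def __init__(self, make_loader: Callable[[], list], should_reload: bool = False):
        self._make_loader = make_loader
        # Stands in for the reload_dataloaders_every_n_epochs check.
        self.should_reload = should_reload
        self._combined_loader: Optional[list] = None
        self.build_count = 0

    def setup_data(self, force: bool = False) -> None:
        # Early exit mirrors the current design: skip if data is already set up
        # and no reload is due. `force=True` bypasses this check.
        if self._combined_loader is not None and not self.should_reload and not force:
            return
        self._combined_loader = self._make_loader()
        self.build_count += 1


loop = ToyFitLoop(lambda: [0])
loop.setup_data()            # first call builds the loader
loop.setup_data()            # early exit, nothing rebuilt
loop.setup_data(force=True)  # proposed flag: rebuild on command
print(loop.build_count)      # 2
```

A keyword argument with a `False` default keeps existing callers unchanged while allowing irregular, on-demand reloads.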

@carmocca carmocca added feature Is an improvement or enhancement loops Related to the Loop API pl Generic label for PyTorch Lightning package and removed question Further information is requested labels Apr 11, 2023
@awaelchli
Contributor

An idea for this part:

This solution is unsatisfactory since I want to a) avoid reloading every epoch and b) be able to reload irregularly and on command.

You could still set Trainer(reload_dataloaders_every_n_epochs=1) just so that the trainer calls the dataloader methods. In there, you can still decide whether you actually want to rebuild the dataloaders or just return the cached one:

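    def train_dataloader(self):
        # `condition` is your own logic; also rebuild on the first call,
        # when no cached loader exists yet
        if condition or getattr(self, "train_dl", None) is None:
            # recreate
            self.train_dl = DataLoader(
                self.train_data,
            )

        return self.train_dl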
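A self-contained version of this caching pattern might look like the sketch below. It is written as a plain mixin so it runs without Lightning; in practice you would mix it into your LightningModule and keep `Trainer(reload_dataloaders_every_n_epochs=1)` so the trainer calls `train_dataloader()` each epoch. The names `CachedTrainDataloaderMixin`, `request_train_reload`, `_build_train_dataloader`, and `_train_reload_requested` are hypothetical.

```python
class CachedTrainDataloaderMixin:
    """Rebuild the train dataloader only when a reload was requested (hypothetical)."""

    _train_dl = None
    _train_reload_requested = False

    def request_train_reload(self) -> None:
        # Call this from any hook (e.g. after swapping the dataset) to rebuild
        # the loader the next time the trainer calls train_dataloader().
        self._train_reload_requested = True

    def _build_train_dataloader(self):
        # Override to construct the real loader, e.g. DataLoader(self.train_data).
        raise NotImplementedError

    def train_dataloader(self):
        # Build on the first call or when requested; otherwise return the cache.
        if self._train_dl is None or self._train_reload_requested:
            self._train_dl = self._build_train_dataloader()
            self._train_reload_requested = False
        return self._train_dl


class Demo(CachedTrainDataloaderMixin):
    def __init__(self):
        self.train_data = [0]
        self.builds = 0

    def _build_train_dataloader(self):
        self.builds += 1
        return list(self.train_data)  # stand-in for DataLoader(self.train_data)


m = Demo()
first = m.train_dataloader()
assert m.train_dataloader() is first  # cached: no rebuild
m.train_data = [1]
m.request_train_reload()
print(m.train_dataloader(), m.builds)  # [1] 2
```

The trainer still calls the hook every epoch, but the loader is only rebuilt when your code asks for it, which covers reloading irregularly and on command.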

@albertfgu

👍 to this issue.

  1. The migration guide is pretty sparse on details. It just says to replace trainer.reset_*_dataloader() with Loop.setup_data(), which is vague. It took me some time to figure out how to actually invoke it: I had to set a breakpoint, dir() the Trainer, and guess and check before arriving at trainer.fit_loop.setup_data(). More documentation would definitely help.
  2. I agree with the original poster that this replacement does not, in fact, have the same functionality as the old method, which is unexpected behavior and worth documenting as well.
  3. The solution suggested by @carmocca, setting trainer.fit_loop._combined_loader = None, seems to have done the trick for me. I also agree with the suggestion to make this a boolean flag passed to Loop.setup_data().

@lorenzomammana

@Borda @carmocca

Adding to this discussion: I also have a custom callback that used reset_xyz_dataloader(), which I'm migrating to Lightning 2.x (right now I'm using 2.1.3).

class FeatureExtractorCallback(Callback):

    def __init__(self, devices, feature_extractor: nn.Module) -> None:
        super().__init__()
        self.devices = devices
        self.feature_extractor = feature_extractor

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit begins."""
        if not hasattr(trainer, "datamodule"):
            raise ValueError("Trainer must have a datamodule attribute.")

        log.info("Extracting features!")
        device = f"cuda:{self.devices[0]}" if (self.devices is not None and isinstance(self.devices, list)) else "cpu"
        self.feature_extractor.to(device)
        train_dataset = convert_feature_dataset(trainer.datamodule.train_dataloader(), self.feature_extractor, device)
        val_dataset = convert_feature_dataset(trainer.datamodule.val_dataloader(), self.feature_extractor, device)
        trainer.datamodule.train_dataset = train_dataset
        trainer.datamodule.val_dataset = val_dataset
        trainer._should_reload_train_dl = True
        trainer._should_reload_val_dl = True
        trainer.fit_loop.setup_data()
        trainer.validate_loop.setup_data()
        trainer._should_reload_train_dl = False
        trainer._should_reload_val_dl = False
        # trainer.reset_train_dataloader()
        # trainer.reset_val_dataloader()


    @rank_zero_only
    def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Called when test begins."""
        if not hasattr(trainer, "datamodule"):
            raise ValueError("Trainer must have a datamodule attribute.")

        log.info("Extracting features!")
        device = f"cuda:{self.devices[0]}" if (self.devices is not None and isinstance(self.devices, list)) else "cpu"
        self.feature_extractor.to(device)
        test_dataset = convert_feature_dataset(trainer.datamodule.test_dataloader(), self.feature_extractor, device)
        trainer.datamodule.test_dataset = test_dataset
        trainer._should_reload_val_dl = True
        trainer.test_loop.setup_data()
        trainer._should_reload_val_dl = False
        # trainer.reset_test_dataloader()

The on_fit_start part is working: the training and validation datasets in the datamodule are replaced and reloaded, and I correctly receive batches coming from the newly created datasets.

But test is not working! While debugging, I see that test_loop.setup_data() is called properly and the datamodule is loaded with the new dataset (_combined_loader contains a reference to the correct dataset), but the batch I receive in test_step comes from the original test dataset, not the updated one. Any idea what is going on? It looks like a bug to me, but I can't work out what's causing it.
