Potential off by 1 error when resuming training of mid-epoch checkpoint #19367

Open
ivnle opened this issue Jan 29, 2024 · 0 comments
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · loops (Related to the Loop API) · ver: 2.1.x

Comments


ivnle commented Jan 29, 2024

Bug description

Here is a simple log of `trainer.global_step` and `batch_idx` at `on_train_batch_start` and `on_train_batch_end` during the fit loop:

[TRAIN STEP START] trainer.global_step=0, batch_idx=0
[TRAIN STEP END] trainer.global_step=1, batch_idx=0
[TRAIN STEP START] trainer.global_step=1, batch_idx=1
[TRAIN STEP END] trainer.global_step=2, batch_idx=1
[TRAIN STEP START] trainer.global_step=2, batch_idx=2
[TRAIN STEP END] trainer.global_step=3, batch_idx=2
[TRAIN STEP START] trainer.global_step=3, batch_idx=3
[TRAIN STEP END] trainer.global_step=4, batch_idx=3
[TRAIN STEP START] trainer.global_step=4, batch_idx=4
[TRAIN STEP END] trainer.global_step=5, batch_idx=4
[VAL STEP START] trainer.global_step=5, batch_idx=0
[VAL STEP START] trainer.global_step=5, batch_idx=1
[VAL STEP START] trainer.global_step=5, batch_idx=2
[VAL STEP START] trainer.global_step=5, batch_idx=3
[VAL STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP START] trainer.global_step=5, batch_idx=5
[TRAIN STEP END] trainer.global_step=6, batch_idx=5
[TRAIN STEP START] trainer.global_step=6, batch_idx=6
[TRAIN STEP END] trainer.global_step=7, batch_idx=6
[TRAIN STEP START] trainer.global_step=7, batch_idx=7
[TRAIN STEP END] trainer.global_step=8, batch_idx=7
[TRAIN STEP START] trainer.global_step=8, batch_idx=8
[TRAIN STEP END] trainer.global_step=9, batch_idx=8
[TRAIN STEP START] trainer.global_step=9, batch_idx=9
[TRAIN STEP END] trainer.global_step=10, batch_idx=9
[VAL STEP START] trainer.global_step=10, batch_idx=0
[VAL STEP START] trainer.global_step=10, batch_idx=1
[VAL STEP START] trainer.global_step=10, batch_idx=2
[VAL STEP START] trainer.global_step=10, batch_idx=3
[VAL STEP START] trainer.global_step=10, batch_idx=4
`Trainer.fit` stopped: `max_steps=10` reached.

Notice that `global_step` and `batch_idx` are equal at `on_train_batch_start`, and `global_step` is 1 greater than `batch_idx` at `on_train_batch_end`. Now, if I save a mid-epoch checkpoint after 5 training steps and resume training from it, I see the following:

[TRAIN STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP END] trainer.global_step=6, batch_idx=4
[VAL STEP START] trainer.global_step=6, batch_idx=0
[VAL STEP START] trainer.global_step=6, batch_idx=1
[VAL STEP START] trainer.global_step=6, batch_idx=2
[VAL STEP START] trainer.global_step=6, batch_idx=3
[VAL STEP START] trainer.global_step=6, batch_idx=4
[TRAIN STEP START] trainer.global_step=6, batch_idx=5
[TRAIN STEP END] trainer.global_step=7, batch_idx=5
[TRAIN STEP START] trainer.global_step=7, batch_idx=6
[TRAIN STEP END] trainer.global_step=8, batch_idx=6
[TRAIN STEP START] trainer.global_step=8, batch_idx=7
[TRAIN STEP END] trainer.global_step=9, batch_idx=7
[TRAIN STEP START] trainer.global_step=9, batch_idx=8
[TRAIN STEP END] trainer.global_step=10, batch_idx=8
`Trainer.fit` stopped: `max_steps=10` reached.

Now the two values are off by 1 at batch start and off by 2 at batch end. This seems to be an issue because it changes when validation and checkpointing are run. In both runs, I have `Trainer(val_check_interval=5, ...)` and `ModelCheckpoint(every_n_train_steps=5, ...)`. In the original run, validation happens after 5 and 10 training steps, as expected. In the resumed run, validation only happens once, after 6 training steps.
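
If the counter that drives the every-5-steps checks is restored one behind `global_step` after resuming, the multiples of 5 land on different optimizer steps. Below is a toy illustration of that arithmetic, using a hypothetical per-batch counter rather than Lightning's actual trigger logic:

```python
# Toy arithmetic (simplified assumption about the trigger, not Lightning's
# actual implementation): a check fires whenever a per-batch counter reaches
# a multiple of 5.

MAX_STEPS = 10
INTERVAL = 5

def steps_where_check_fires(restored_counter, restored_global_step):
    """Return the global_step values at which an every-INTERVAL check fires."""
    fired = []
    counter, global_step = restored_counter, restored_global_step
    while global_step < MAX_STEPS:
        global_step += 1   # optimizer step
        counter += 1       # hypothetical per-batch counter driving the check
        if counter % INTERVAL == 0:
            fired.append(global_step)
    return fired

print(steps_where_check_fires(0, 0))  # fresh run: [5, 10]
print(steps_where_check_fires(4, 5))  # resumed run, counter one behind: [6]
```

With the counter one behind, the check fires at `global_step=6`, and the next multiple would only be reached at step 11, past `max_steps=10`, which matches the resumed log above.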

My initial guess is that this is happening because `self.batch_progress.increment_completed()` in `src/lightning/pytorch/loops/training_epoch_loop.py` is called after

call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
trainer._logger_connector.on_batch_end()

so the checkpoint thinks we've only completed `global_step - 1` training steps.
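
To make the suspected ordering concrete, here is a minimal, self-contained sketch with toy classes (not the actual Lightning loop code): the optimizer step bumps `global_step`, the `on_train_batch_end` hooks fire, and only afterwards is the batch marked completed, so a checkpoint written from inside the hook would record one fewer completed batch than optimizer steps.

```python
# Simplified sketch of the suspected ordering (toy classes, not Lightning's
# actual loop code). A checkpoint saved inside on_train_batch_end would see
# completed == global_step - 1, which reappears as the offset after resume.

class BatchProgress:
    def __init__(self):
        self.completed = 0  # the value that would be persisted in the checkpoint

    def increment_completed(self):
        self.completed += 1


class ToyEpochLoop:
    def __init__(self):
        self.global_step = 0
        self.batch_progress = BatchProgress()

    def on_train_batch_end(self, batch_idx):
        # a ModelCheckpoint-style callback saving here captures the mismatch
        print(f"[END] global_step={self.global_step}, "
              f"completed={self.batch_progress.completed}, batch_idx={batch_idx}")

    def run_batch(self, batch_idx):
        self.global_step += 1                       # optimizer step finished
        self.on_train_batch_end(batch_idx)          # batch-end hooks fire first ...
        self.batch_progress.increment_completed()   # ... completion recorded only after


loop = ToyEpochLoop()
for i in range(5):
    loop.run_batch(i)
```

Running it prints `global_step` 1..5 alongside `completed` 0..4, the same `global_step - 1` relationship described above.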

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from torch.utils.data import DataLoader, random_split


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.model = torch.compile(self.model)

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("test_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    # def on_train_batch_end(self, outputs, batch, batch_idx):
    #     print(f"{self.trainer.global_step=}, {batch_idx=}")


class MyCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[TRAIN STEP START] {trainer.global_step=}, {batch_idx=}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f"[TRAIN STEP END] {trainer.global_step=}, {batch_idx=}")

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[VAL STEP START] {trainer.global_step=}, {batch_idx=}")


class _ModelCheckpoint(ModelCheckpoint):
    """Unused in this repro; the stock ModelCheckpoint below is what gets passed to the Trainer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def main():
    L.seed_everything(42)

    # Data
    dataset = WikiText2()

    # Split data into train, val, and test
    n = len(dataset)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [n - 200, 100, 100]
    )
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(f"{len(train_dataset)=}")

    # Model
    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5,
        save_top_k=-1,
        enable_version_counter=False,
        verbose=True,
    )
    my_callback = MyCallback()

    # Trainer
    trainer = L.Trainer(
        max_steps=10,
        val_check_interval=5,
        limit_val_batches=5,
        callbacks=[my_callback, checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    # trainer.test(model, test_dataloader)

    # Resume training from checkpoint
    ckpt_path = "lightning_logs/version_0/checkpoints/epoch=0-step=5.ckpt"
    print(f"Resuming training from checkpoint {ckpt_path}")
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        ckpt_path=ckpt_path,
    )


if __name__ == "__main__":
    main()

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca @justusschock

ivnle added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jan 29, 2024
awaelchli added the loops (Related to the Loop API) and help wanted (Open to be worked on) labels and removed the needs triage (Waiting to be triaged by maintainers) label on Jan 31, 2024