Step when validation happens drifts for val_check_interval when gradient accumulation turned on #17207

Open
hrukalive opened this issue Mar 27, 2023 · 4 comments
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · loops (Related to the Loop API) · ver: 2.0.x
Milestone

Comments

@hrukalive

hrukalive commented Mar 27, 2023

Bug description

First of all, my task relies on step count instead of epochs, so I run validation checks by steps and save checkpoints right after. However, once I turned gradient accumulation on and the per-epoch batch count is not divisible by the accumulation factor, the actual step at which validation (and therefore checkpointing) happens starts to drift.

In the example below, I override the _save_checkpoint function to print the actual file name, and it turns out to drift. My setting is val_check_interval=accumulation*5 so that validation runs every 5 effective optimizer steps, with accumulation=3 and 67 batches per epoch, so there is one leftover batch.
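
For reference, the drift can be reproduced with plain arithmetic (a minimal sketch outside Lightning, assuming the trainer forces an optimizer step on the leftover batch at the end of each epoch, so each epoch contributes 23 optimizer steps instead of 22):

# Sketch (no Lightning): where validation lands in terms of global_step,
# assuming an extra optimizer step is taken on each epoch's leftover batch.
batches_per_epoch = 67
accum = 3
val_check_interval = 5 * accum                     # every 15 training batches, counted across epochs
steps_per_epoch = -(-batches_per_epoch // accum)   # ceil(67 / 3) = 23

for total_batch in range(val_check_interval, 3 * batches_per_epoch, val_check_interval):
    epoch, batch_in_epoch = divmod(total_batch, batches_per_epoch)
    global_step = epoch * steps_per_epoch + batch_in_epoch // accum
    print(total_batch, global_step)  # ..., 120 -> 40, 135 -> 46 (drift), 150 -> 51, ...

The printed steps match the checkpoint log below: 5, 10, ..., 40, then 46 instead of 45 once the second epoch starts.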

How to reproduce the bug

import numpy as np
import pathlib

import time
import torch
import torch.nn as nn
import torch.optim

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

class Quadratic(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        time.sleep(0.02)  # slow the forward pass down a bit
        return self.a * x * x + self.b * x + self.c
    
    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss 

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        monitor_candidates = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self._monitor_candidates(trainer).items()}
        print("\n", "Save checkpoint, global_step: ", trainer.global_step, pathlib.Path(filepath).stem, "monitor_candidates: " + str(monitor_candidates), "\n", flush=True)
        
    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # print("Remove checkpoint: ", filepath, flush=True)
        pass
    
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=float, default=2.0)
    parser.add_argument('-b', type=float, default=3.0)
    parser.add_argument('-c', type=float, default=4.0)
    parser.add_argument('--epoch', type=int, default=500)
    args = parser.parse_args()

    x = torch.from_numpy(np.random.uniform(-10, 10, 2144)).float() # Make 67 batches
    y = args.a * x * x + args.b * x + args.c
    x2 = torch.from_numpy(np.random.uniform(-10, 10, 100)).float()
    y2 = args.a * x2 * x2 + args.b * x2 + args.c

    dataset = torch.utils.data.TensorDataset(x, y)
    val_dataset = torch.utils.data.TensorDataset(x2, y2)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = Quadratic()
    
    ####
    accumulate_grad_batches = 3
    val_check_interval = 5 * accumulate_grad_batches  # validate every 5 effective optimizer steps, expressed in training batches
    ####

    trainer = pl.Trainer(max_epochs=args.epoch, accelerator='cpu', callbacks=[CustomModelCheckpoint(
                    dirpath='.',
                    filename='steps_{step}',
                    monitor='step',
                    mode='max',
                    save_last=False,
                    save_top_k=5
                )],
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=None,
            num_sanity_val_steps=0,
            accumulate_grad_batches=accumulate_grad_batches)
    trainer.fit(model, dataloader, val_dataloader)
    
    # Print the results
    print("a = ", model.a.item())
    print("b = ", model.b.item())
    print("c = ", model.c.item())

Error messages and logs

Save checkpoint, global_step:  5 steps_step=5 monitor_candidates: {'epoch': 0, 'step': 5}
Save checkpoint, global_step:  10 steps_step=10 monitor_candidates: {'epoch': 0, 'step': 10}
Save checkpoint, global_step:  15 steps_step=15 monitor_candidates: {'epoch': 0, 'step': 15}
Save checkpoint, global_step:  20 steps_step=20 monitor_candidates: {'epoch': 0, 'step': 20}
Save checkpoint, global_step:  25 steps_step=25 monitor_candidates: {'epoch': 1, 'step': 25}
Save checkpoint, global_step:  30 steps_step=30 monitor_candidates: {'epoch': 1, 'step': 30}
Save checkpoint, global_step:  35 steps_step=35 monitor_candidates: {'epoch': 1, 'step': 35}
Save checkpoint, global_step:  40 steps_step=40 monitor_candidates: {'epoch': 1, 'step': 40}

Save checkpoint, global_step:  46 steps_step=46 monitor_candidates: {'epoch': 2, 'step': 46}  <-- drift
Save checkpoint, global_step:  51 steps_step=51 monitor_candidates: {'epoch': 2, 'step': 51}
Save checkpoint, global_step:  56 steps_step=56 monitor_candidates: {'epoch': 2, 'step': 56}

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
        - available:         True
        - version:           11.7
* Lightning:
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.0.0
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
* Packages:
        - absl-py:           1.3.0
        - aiobotocore:       2.4.2
        - aiohttp:           3.8.4
        - aioitertools:      0.11.0
        - aiosignal:         1.3.1
        - altgraph:          0.17.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - arrow:             1.2.3
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - audioread:         3.0.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.12.0
        - blessed:           1.20.0
        - blinker:           1.4
        - botocore:          1.27.59
        - brotlipy:          0.7.0
        - cachetools:        5.3.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.0.4
        - click:             8.1.3
        - contourpy:         1.0.7
        - croniter:          1.3.8
        - cryptography:      39.0.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - decorator:         5.1.1
        - deepdiff:          6.3.0
        - distance:          0.1.3
        - dnspython:         2.3.0
        - einops:            0.6.0
        - email-validator:   1.3.1
        - et-xmlfile:        1.0.1
        - fastapi:           0.88.0
        - fire:              0.5.0
        - flit-core:         3.8.0
        - fonttools:         4.39.2
        - frozenlist:        1.3.3
        - fsspec:            2023.3.0
        - future:            0.18.2
        - g2p-en:            2.1.0
        - g2pm:              0.1.2.5
        - google-auth:       2.16.3
        - google-auth-oauthlib: 0.4.6
        - grpcio:            1.51.3
        - h11:               0.14.0
        - h5py:              3.7.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - idna:              3.4
        - imageio:           2.23.0
        - importlib-metadata: 6.1.0
        - inflect:           6.0.2
        - inquirer:          3.1.3
        - itsdangerous:      2.1.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.2.0
        - kiwisolver:        1.4.4
        - librosa:           0.9.1
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - llvmlite:          0.39.1
        - markdown:          3.4.3
        - markdown-it-py:    2.2.0
        - markupsafe:        2.1.2
        - matplotlib:        3.6.2
        - mdurl:             0.1.2
        - mkl-fft:           1.3.1
        - mkl-random:        1.2.2
        - mkl-service:       2.4.0
        - multidict:         6.0.4
        - networkx:          3.0
        - nltk:              3.8.1
        - numba:             0.56.4
        - numpy:             1.23.5
        - oauthlib:          3.2.2
        - ordered-set:       4.1.0
        - orjson:            3.8.8
        - packaging:         23.0
        - pillow:            9.4.0
        - pip:               23.0.1
        - platformdirs:      3.1.1
        - pooch:             1.7.0
        - praat-parselmouth: 0.4.3
        - protobuf:          3.13.0
        - psutil:            5.9.4
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycparser:         2.21
        - pycwt:             0.3.0a22
        - pydantic:          1.10.7
        - pygments:          2.14.0
        - pyjwt:             2.6.0
        - pyloudnorm:        0.1.0
        - pyopenssl:         23.0.0
        - pyparsing:         3.0.9
        - pypinyin:          0.39.0
        - pysocks:           1.7.1
        - python-dateutil:   2.8.2
        - python-dotenv:     1.0.0
        - python-editor:     1.0.4
        - python-levenshtein: 0.12.2
        - python-multipart:  0.0.6
        - pytorch-lightning: 2.0.0
        - pytz:              2022.7.1
        - pywavelets:        1.4.1
        - pyyaml:            6.0
        - readchar:          4.0.5
        - regex:             2023.3.23
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - resampy:           0.4.2
        - resemblyzer:       0.1.1.dev0
        - rfc3986:           1.5.0
        - rich:              13.3.2
        - rsa:               4.9
        - s3fs:              2023.3.0
        - scikit-image:      0.19.3
        - scikit-learn:      1.2.2
        - scipy:             1.9.3
        - setuptools:        65.6.3
        - six:               1.16.0
        - snakeviz:          2.1.1
        - sniffio:           1.3.0
        - soundfile:         0.12.1
        - soupsieve:         2.4
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - tensorboard:       2.11.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorboardx:      2.6
        - termcolor:         2.2.0
        - threadpoolctl:     3.1.0
        - tifffile:          2023.3.21
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
        - tornado:           6.2
        - tqdm:              4.65.0
        - traitlets:         5.9.0
        - typing:            3.7.4.3
        - typing-extensions: 4.4.0
        - ujson:             5.7.0
        - urllib3:           1.26.14
        - uvicorn:           0.21.1
        - uvloop:            0.17.0
        - watchfiles:        0.18.1
        - wcwidth:           0.2.6
        - webrtcvad:         2.0.10
        - websocket-client:  1.5.1
        - websockets:        10.4
        - werkzeug:          2.2.3
        - wheel:             0.38.4
        - wrapt:             1.15.0
        - yarl:              1.8.2
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.16
        - version:           #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022

More info

Besides this phenomenon, I have two more questions:

  1. Why is val_check_interval tied to the number of batches rather than global_step?
  2. Why is validation re-run after loading a checkpoint that was saved right after a validation run? This also produces a duplicate checkpoint, which is very frustrating.

cc @carmocca @justusschock

@hrukalive hrukalive added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Mar 27, 2023
@hrukalive hrukalive changed the title Step of checkpointing drifts for val_check_interval when gradient accumulation turned on Step when validation happens drifts for val_check_interval when gradient accumulation turned on Mar 29, 2023
@hrukalive
Author

I think it is actually the moment when validation happens that drifts; the checkpoint saving is just a side effect.

@awaelchli awaelchli added loops Related to the Loop API help wanted Open to be worked on and removed needs triage Waiting to be triaged by maintainers labels May 1, 2023
@awaelchli awaelchli modified the milestones: 2.0.x, v1.9.x May 1, 2023
@bkiat1123
Contributor

The validation check tracks training batches instead of training steps. According to the documentation:

An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training.

However, the number of training batches does not always equal the number of training steps (global steps).

The training step is (total_batch_idx // accumulate_grad_batches) + (accumulates_on_final_batch * epoch_trained). The accumulates_on_final_batch term is where the drift happens.

I think it would make sense to validate after N training steps instead of N training batches. Other modules such as the loggers and ModelCheckpoint also use global steps to track training progress.

I propose we change from

https://github.com/Lightning-AI/lightning/blob/83f683243dde71898e2110b12fd5c78ebac5418b/src/lightning/pytorch/loops/training_epoch_loop.py#L392-L398

to

elif self.trainer.val_check_batch != float("inf"):
    # if `check_val_every_n_epoch` is `None`, run a validation loop every n training steps
    # else condition it based on the batch_idx of the current epoch
    next_iteration = self.global_step if self.trainer.check_val_every_n_epoch is None else self.batch_idx + 1
    is_val_check_batch = next_iteration % self.trainer.val_check_batch == 0

@Borda Borda modified the milestones: 2.0.x, 2.1.x Oct 12, 2023
@awaelchli awaelchli modified the milestones: 2.1.x, 2.2.x Feb 8, 2024
@anjali-chadha

Is there a plan to add step-based validation checks in Lightning?

Until Lightning adds official support for this, any recommendations on how we can override the default Lightning behavior and use number of steps instead of batches to trigger the validation?

@hrukalive
Author

Is there a plan to add step-based validation checks in Lightning?

Until Lightning adds official support for this, any recommendations on how we can override the default Lightning behavior and use number of steps instead of batches to trigger the validation?

Right now, for myself, I have to discard the leftover batch so that the number of batches per epoch is a multiple of the accumulation factor (see the sketch below).
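
For example (a minimal sketch, assuming it is acceptable to simply drop the extra data), the dataset can be trimmed so that the number of batches per epoch is a multiple of accumulate_grad_batches:

# Hypothetical workaround sketch: trim the dataset so that the batch count per
# epoch is a multiple of accumulate_grad_batches, keeping global_step aligned.
batch_size = 32
accumulate_grad_batches = 3
full_batches = len(dataset) // batch_size                                 # 67 in the repro above
usable_batches = full_batches - full_batches % accumulate_grad_batches    # 66
dataset = torch.utils.data.Subset(dataset, range(usable_batches * batch_size))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

Setting the Trainer's limit_train_batches to a multiple of accumulate_grad_batches should have a similar effect.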

@awaelchli awaelchli modified the milestones: 2.2.x, 2.3.x Jun 13, 2024
@awaelchli awaelchli modified the milestones: 2.3.x, 2.4.x Aug 7, 2024