Error when disabling an optimizer with native AMP turned on #20116

Open
schopra8 opened this issue Jul 22, 2024 · 1 comment
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.2.x

Comments

schopra8 commented Jul 22, 2024

Bug description

I'm using 2 optimizers and trying to train with AMP (FP16). I can take steps with my first optimizer. When I take my first step with the second optimizer, I get the following error:

  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
    len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.

The same setup trains correctly in FP32, so the problem seems specific to AMP.
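A simplified, pure-Python model of the failing check may help narrow this down. It is an assumption based on reading `torch/amp/grad_scaler.py`, not PyTorch code: `GradScaler.step()` appears to record one inf/NaN check per device that holds at least one non-`None` `.grad` of the optimizer's parameters, and then asserts that at least one check was recorded. A plain FP32 `optimizer.step()` simply skips parameters whose `.grad` is `None`. `Param`, `plain_step`, and `scaled_step` are hypothetical names used only for this sketch.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a parameter tensor: only its device and gradient matter here."""
    device: str
    grad: Optional[float] = None


def plain_step(params: List[Param]) -> int:
    """Model of an FP32 optimizer step: parameters without a gradient are skipped silently.

    Returns how many parameters would be updated.
    """
    return sum(1 for p in params if p.grad is not None)


def scaled_step(params: List[Param]) -> int:
    """Model of GradScaler.step() (assumed behavior): record one inf/NaN check per device
    holding at least one gradient, then require that some check was recorded."""
    found_inf_per_device: Dict[str, bool] = {}
    for p in params:
        if p.grad is None:
            continue  # parameters without a gradient are never inspected
        previous = found_inf_per_device.get(p.device, False)
        found_inf_per_device[p.device] = previous or not math.isfinite(p.grad)
    # Same assertion message as in the error above
    assert len(found_inf_per_device) > 0, "No inf checks were recorded for this optimizer."
    if any(found_inf_per_device.values()):
        return 0  # an inf/NaN was found: the scaler skips the optimizer step
    return plain_step(params)
```

Under this model, if every parameter owned by the second optimizer still has `grad=None` when `opt.step()` runs (for example, because the loss computed for model 2 never reaches those parameters), the FP32 path returns without updating anything, while the AMP path raises the assertion seen here.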

What version are you seeing the problem on?

version 2.3.3

How to reproduce the bug

from typing import Dict


def training_step(self, batch: Dict, batch_idx: int):
    """
    We have 2 sets of optimizers.
    Every N batches (self.n_batches_per_optimizer), we make an optimizer update and
    switch the optimizer to update.

    If self.n_batches_per_optimizer = 1, then we make updates every batch and alternate optimizers
    every batch.

    If self.n_batches_per_optimizer > 1, then we're doing gradient accumulation, where we make
    updates every n_batches_per_optimizer batches and alternate optimizers every
    n_batches_per_optimizer batches.
    """
    opts = self.optimizers()
    # Index of the optimizer whose turn it is; advances every n_batches_per_optimizer batches
    current_cycle = (batch_idx // self.n_batches_per_optimizer) % len(opts)
    opt = opts[current_cycle]
    # NOTE: this clears gradients on every batch, so with n_batches_per_optimizer > 1
    # only the last batch's gradients reach opt.step()
    opt.zero_grad()

    if current_cycle == 0:
        compute_model_1_loss = True
    elif current_cycle == 1:
        compute_model_1_loss = False
    else:
        raise NotImplementedError(f"Unknown optimizer {current_cycle}")

    # toggle_model() sets requires_grad=False on parameters not owned by `opt` and restores them on exit
    with opt.toggle_model():
        loss = self.inner_training_step(batch=batch, compute_model_1_loss=compute_model_1_loss)
        self.manual_backward(loss=loss)

        # Perform the optimization step every n_batches_per_optimizer batches
        if (batch_idx + 1) % self.n_batches_per_optimizer == 0:
            if not compute_model_1_loss:
                print("About to step the model 2 optimizer ...")
            opt.step()
            opt.zero_grad()
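The optimizer schedule described in the docstring can be checked without a GPU. The helper below is a hypothetical, pure-Python restatement of the two expressions used in `training_step` (the `current_cycle` formula and the `(batch_idx + 1) % n_batches_per_optimizer == 0` step condition); it lists, for each batch, which optimizer is active and whether `opt.step()` would be called.

```python
from typing import List, Tuple


def optimizer_schedule(
    num_batches: int, n_batches_per_optimizer: int, num_optimizers: int = 2
) -> List[Tuple[int, int, bool]]:
    """Return (batch_idx, optimizer_index, takes_step) for each batch.

    Hypothetical helper mirroring the expressions in training_step.
    """
    schedule = []
    for batch_idx in range(num_batches):
        # Same formula as `current_cycle` in training_step
        opt_idx = (batch_idx // n_batches_per_optimizer) % num_optimizers
        # Same condition that guards opt.step() in training_step
        takes_step = (batch_idx + 1) % n_batches_per_optimizer == 0
        schedule.append((batch_idx, opt_idx, takes_step))
    return schedule


for row in optimizer_schedule(num_batches=6, n_batches_per_optimizer=2):
    print(row)
```

Each optimizer's `step()` falls on the last batch of its cycle, so the step condition and the optimizer switch stay aligned; the second optimizer's first `step()` happens at `batch_idx == 2 * n_batches_per_optimizer - 1`.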

Error messages and logs

Traceback (most recent call last):
  File "/home/sahil/train.py", line 82, in <module>
    main(config)
  File "/home/sahil/train.py", line 62, in main
    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 252, in advance
    batch_output = self.manual_optimization.run(kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 94, in run
    self.advance(kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 114, in advance
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/home/sahil/model/model.py", line 169, in training_step
    opt.step()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py", line 93, in optimizer_step
    step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
    len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA A100-SXM4-80GB
        - available:         True
        - version:           12.1
* Lightning:
        - lightning-utilities: 0.11.5
        - pytorch-lightning: 2.3.3
        - torch:             2.3.1
        - torchmetrics:      1.4.0.post0
        - torchvision:       0.18.1
* Packages:
        - aiohttp:           3.9.5
        - aiosignal:         1.3.1
        - annotated-types:   0.7.0
        - antlr4-python3-runtime: 4.9.3
        - anyio:             4.4.0
        - argon2-cffi:       23.1.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.3.0
        - asttokens:         2.4.1
        - async-lru:         2.0.4
        - async-timeout:     4.0.3
        - attrs:             23.2.0
        - autocommand:       2.2.2
        - babel:             2.15.0
        - backports.tarfile: 1.2.0
        - beautifulsoup4:    4.12.3
        - bitsandbytes:      0.43.1
        - bleach:            6.1.0
        - boto3:             1.34.144
        - botocore:          1.34.144
        - braceexpand:       0.1.7
        - certifi:           2024.7.4
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu12:  2.20.5
        - nvidia-nvjitlink-cu12: 12.5.82
        - nvidia-nvtx-cu12:  12.1.105
        - omegaconf:         2.3.0
        - opencv-python:     4.10.0.84
        - ordered-set:       4.1.0
        - overrides:         7.7.0
        - packaging:         24.1
        - pandocfilters:     1.5.1
        - parso:             0.8.4
        - pexpect:           4.9.0
        - pillow:            10.4.0
        - pip:               24.1
        - platformdirs:      4.2.2
        - pre-commit:        3.7.1
        - proglog:           0.1.10
        - prometheus-client: 0.20.0
        - prompt-toolkit:    3.0.47
        - protobuf:          5.27.2
        - psutil:            6.0.0
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pycparser:         2.22
        - pydantic:          2.8.2
        - pydantic-core:     2.20.1
        - pydantic-settings: 2.3.4
        - pygments:          2.18.0
        - python-dateutil:   2.9.0.post0
        - python-dotenv:     1.0.1
        - python-json-logger: 2.0.7
        - pytorch-lightning: 2.3.3
        - pyyaml:            6.0.1
        - pyzmq:             26.0.3
        - referencing:       0.35.1
        - requests:          2.32.3
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rpds-py:           0.19.0
        - s3transfer:        0.10.2
        - send2trash:        1.8.3
        - sentry-sdk:        2.10.0
        - setproctitle:      1.3.3
        - setuptools:        71.0.2
        - six:               1.16.0
        - smmap:             5.0.1
        - sniffio:           1.3.1
        - soupsieve:         2.5
        - stack-data:        0.6.3
        - sympy:             1.13.0
        - terminado:         0.18.1
        - tinycss2:          1.3.0
        - tomli:             2.0.1
        - torch:             2.3.1
        - torchmetrics:      1.4.0.post0
        - torchvision:       0.18.1
        - tornado:           6.4.1
        - tqdm:              4.66.4
        - traitlets:         5.14.3
        - triton:            2.3.1
        - typeguard:         4.3.0
        - types-python-dateutil: 2.9.0.20240316
        - typing-extensions: 4.12.2
        - uri-template:      1.3.0
        - urllib3:           2.2.2
        - virtualenv:        20.26.3
        - wandb:             0.17.4
        - wcwidth:           0.2.13
        - webcolors:         24.6.0
        - webdataset:        0.2.86
        - webencodings:      0.5.1
        - websocket-client:  1.8.0
        - wheel:             0.43.0
        - yarl:              1.9.4
        - zipp:              3.19.2
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:
        - python:            3.10.14
        - release:           5.10.0-31-cloud-amd64
        - version:           #1 SMP Debian 5.10.221-1 (2024-07-14)

More info

No response

@schopra8 schopra8 added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jul 22, 2024
@schopra8
Author

There were similar issues reported a few years back -- #7792

And they seem to have been solved -- #7975

So I'm not sure whether the bug was reintroduced in later versions OR whether I'm missing something in my example code.
