Bug description
I'm using 2 optimizers and trying to train with AMP (FP16). I can take steps with my first optimizer, but when I take my first step with the second optimizer I get the following error:
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.
I can train this correctly in FP32 -- so it seems to be an issue with AMP.
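For reference, the training_step below assumes manual optimization with two optimizers and 16-bit mixed precision. A minimal sketch of the kind of setup implied (module structure, layer sizes, and optimizer choices here are illustrative, not the actual model code):

import pytorch_lightning as pl
import torch


class TwoOptimizerModule(pl.LightningModule):  # hypothetical stand-in for the real module
    def __init__(self):
        super().__init__()
        # Manual optimization is required to use self.optimizers() / self.manual_backward().
        self.automatic_optimization = False
        self.model_1 = torch.nn.Linear(8, 8)
        self.model_2 = torch.nn.Linear(8, 8)
        self.n_batches_per_optimizer = 1

    def configure_optimizers(self):
        # One optimizer per sub-model; training_step alternates between them.
        return (
            torch.optim.Adam(self.model_1.parameters(), lr=1e-3),
            torch.optim.Adam(self.model_2.parameters(), lr=1e-3),
        )


trainer = pl.Trainer(precision="16-mixed", accelerator="gpu", devices=1)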
What version are you seeing the problem on?
version 2.3.3
How to reproduce the bug
def training_step(self, batch: Dict, batch_idx: int):
    """We have 2 sets of optimizers. Every N batches (self.n_batches_per_optimizer), we make an
    optimizer update and switch the optimizer to update. If self.n_batches_per_optimizer = 1,
    then we make updates every batch and alternate optimizers every batch. If
    self.n_batches_per_optimizer > 1, then we're doing gradient accumulation, making updates
    every n_batches_per_optimizer batches and alternating optimizers every
    n_batches_per_optimizer batches.
    """
    opts = self.optimizers()
    current_cycle = (batch_idx // self.n_batches_per_optimizer) % len(opts)
    opt = opts[current_cycle]
    opt.zero_grad()

    if current_cycle == 0:
        compute_model_1_loss = True
    elif current_cycle == 1:
        compute_model_1_loss = False
    else:
        raise NotImplementedError(f"Unknown optimizer {current_cycle}")

    with opt.toggle_model():
        loss = self.inner_training_step(batch=batch, compute_model_1_loss=compute_model_1_loss)
        self.manual_backward(loss=loss)

        # Perform the optimization step every accumulate_grad_batches steps
        if (batch_idx + 1) % self.n_batches_per_optimizer == 0:
            if not compute_model_1_loss:
                print("About to compute model 2 loss ...")
            opt.step()
            opt.zero_grad()
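To make the cycling concrete, this is the schedule the index arithmetic above produces for two optimizers with n_batches_per_optimizer = 2 (a standalone illustration, not part of the module):

n_batches_per_optimizer = 2
n_opts = 2
for batch_idx in range(8):
    current_cycle = (batch_idx // n_batches_per_optimizer) % n_opts
    takes_step = (batch_idx + 1) % n_batches_per_optimizer == 0
    print(batch_idx, "optimizer", current_cycle, "step" if takes_step else "accumulate")
# Batches 0-1 use optimizer 0 (step at batch 1), batches 2-3 use optimizer 1 (step at batch 3), and so on.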
Error messages and logs
Traceback (most recent call last):
File "/home/sahil/train.py", line 82, in <module>
main(config)
File "/home/sahil/train.py", line 62, in main
trainer.fit(model, datamodule=data_module, ckpt_path=ckpt)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
results = self._run_stage()
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
self.fit_loop.run()
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 252, in advance
batch_output = self.manual_optimization.run(kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 94, in run
self.advance(kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 114, in advance
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/home/sahil/model/model.py", line 169, in training_step
opt.step()
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py", line 93, in optimizer_step
step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type]
File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.
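For context on the assertion itself: GradScaler.step() only records an inf check for devices on which the optimizer's parameters actually have gradients, so if every parameter owned by that optimizer still has grad=None when step() is called, nothing is recorded and this exact assertion fires. A minimal plain-PyTorch sketch (toy parameters, assumes a CUDA device; not the Lightning code path) that triggers the same message:

import torch

p1 = torch.nn.Parameter(torch.randn(4, device="cuda"))
p2 = torch.nn.Parameter(torch.randn(4, device="cuda"))
opt1 = torch.optim.SGD([p1], lr=0.1)
opt2 = torch.optim.SGD([p2], lr=0.1)
scaler = torch.amp.GradScaler("cuda")

loss = (p1 * 2).sum()          # the loss depends only on p1, so p2.grad stays None
scaler.scale(loss).backward()
scaler.step(opt1)              # fine: p1 has a gradient to unscale and inf-check
scaler.step(opt2)              # AssertionError: No inf checks were recorded for this optimizer.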
Environment
Current environment
More info
No response