Bug description
When running Lightning with a multi-device training strategy (e.g. DDP), using the `OnExceptionCheckpoint` callback:
- silently swallows the original exception, which makes it hard to identify the cause of errors
- results in an NCCL timeout
This is due to the following:
- When we catch an exception, it gets handled by `_call_and_handle_interrupt`, which calls into `_interrupt`:
  - We are supposed to re-raise the original exception at the end of this function, but we never get there, because...
- In `_interrupt`, we call `_call_callback_hooks`, which calls the `on_exception` callbacks:
  - If `OnExceptionCheckpoint` is enabled, we then call that callback. However, we never finish executing it, because that callback calls `trainer.save_checkpoint`: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/callbacks/on_exception_checkpoint.py#L67
  - `trainer.save_checkpoint` saves the checkpoint and then calls `self.strategy.barrier("Trainer.save_checkpoint")`, which waits for the other processes to reach that barrier. However, the processes that did not raise an exception never hit this codepath, so we never advance past that barrier until it times out (a standalone sketch of this deadlock follows below).
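To distill the deadlock outside of Lightning, here is a hypothetical standalone sketch (an illustration, not code from the report): it uses `torch.distributed` with the `gloo` backend on CPU and a short timeout so it fails fast; with NCCL the same pattern ends in the collective timeout described above.

```python
# Hypothetical standalone sketch of the deadlock (not Lightning code):
# only the rank that raised reaches the barrier, so the collective never
# completes and the original exception is swallowed.
import datetime
import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    # Short timeout so the demo fails fast; NCCL's default is much longer.
    dist.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=10),
    )
    try:
        if rank == 0:
            raise RuntimeError("exception on rank 0 only")
        time.sleep(60)  # the other rank keeps "training" and never runs the callback
    except RuntimeError:
        # Mirrors OnExceptionCheckpoint -> trainer.save_checkpoint -> barrier():
        # rank 0 waits here until the collective times out, and the original
        # RuntimeError is never re-raised or printed.
        dist.barrier()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```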
As described in the docstring for `Trainer.save_checkpoint`:

> This method needs to be called on all processes in case the selected strategy is handling distributed checkpointing.

In practice, this means that our jobs eventually time out with an NCCL error and never print the original exception.
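For concreteness, a minimal reproduction along these lines (hypothetical module, data, and two-GPU DDP setup; not from an actual run) would look roughly like this:

```python
# Hypothetical reproduction sketch: rank 0 raises mid-epoch, OnExceptionCheckpoint
# calls trainer.save_checkpoint() on that rank only, and its internal barrier
# waits on rank 1, which never enters this codepath -> NCCL timeout, no traceback.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.callbacks import OnExceptionCheckpoint


class BoomModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Raise on a single rank only, partway through the epoch.
        if self.global_rank == 0 and batch_idx == 2:
            raise RuntimeError("boom on rank 0 only")
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        max_epochs=1,
        callbacks=[OnExceptionCheckpoint(dirpath="checkpoints")],
    )
    trainer.fit(BoomModule(), DataLoader(data, batch_size=32))
```

The same hang should also reproduce on CPU with the gloo backend, just surfacing as a gloo timeout instead of an NCCL one.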
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
More info
No response