This small Python script converts PyTorch checkpoints trained with `nn.DataParallel` (multi-GPU)
to be compatible with a single GPU or CPU setup.
Many PyTorch users face issues when loading multi-GPU checkpoints on a single device,
because the `state_dict` keys are prefixed with `module.`. This script removes that prefix
and saves a new checkpoint that works on a single GPU or CPU.
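Concretely, the conversion is just a key-renaming pass over the state dict. A minimal sketch of that step (the function name here is illustrative, not necessarily what the script uses internally):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Return a copy of state_dict with the leading 'module.' removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # nn.DataParallel saves parameters under keys like 'module.layer1.weight'.
        new_key = key[len("module."):] if key.startswith("module.") else key
        cleaned[new_key] = value
    return cleaned
```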
- Supports `.pth` and `.pt` files
- Automatically detects whether the checkpoint contains a `state_dict` or a full model (see the sketch after this list)
- Saves a new checkpoint with `_unparalleled` appended to the original filename
- Easy to reuse on any checkpoint
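As a rough sketch of how the detection and save steps might fit together (reusing `strip_module_prefix` from above; the actual script's internals may differ):

```python
import os

import torch

def convert(checkpoint_path):
    # map_location='cpu' lets a checkpoint saved on GPU load on any machine.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        # Wrapped checkpoint: clean the nested state_dict and keep any
        # extra entries (epoch, optimizer state, ...) untouched.
        checkpoint["state_dict"] = strip_module_prefix(checkpoint["state_dict"])
    elif isinstance(checkpoint, dict):
        # The checkpoint is a bare state_dict.
        checkpoint = strip_module_prefix(checkpoint)
    else:
        # A full pickled model: unwrap the DataParallel container if present.
        checkpoint = getattr(checkpoint, "module", checkpoint)

    root, ext = os.path.splitext(checkpoint_path)  # works for both .pth and .pt
    out_path = f"{root}_unparalleled{ext}"
    torch.save(checkpoint, out_path)
    return out_path
```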
- Clone or download this script.
- Edit the `checkpoint_path` variable inside the script to point to your checkpoint file.
- Run the script: `python convert_checkpoint.py`
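The converted file then loads on a single device as usual; `MyModel` and the filename below are placeholders for your own model class and output file:

```python
import torch

model = MyModel()  # placeholder: your own model class
state = torch.load("checkpoint_unparalleled.pth", map_location="cpu")
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]  # unwrap if the checkpoint wraps its weights
model.load_state_dict(state)
```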