Skip to content

Conversation

jiayangshi
Copy link

Problem

In current save_checkpoint function, if the model is on multiple GPUs, i.e. model is a instance of torch.nn.DataParallel, and then the saved checkpoint could not be loaded again.

Describe your changes

Follow the pytorch tutorial, first the current model is checked, if it is a instance of DataParallel class.
If the model is a instance of DataParallel class, then model.module.state_dict() is saved instead of model.state_dict() in current implementation.

@Bjarten Bjarten deleted the branch Bjarten:main October 14, 2024 03:52
@Bjarten Bjarten closed this Oct 14, 2024
@Bjarten Bjarten reopened this Oct 14, 2024
@Bjarten Bjarten changed the base branch from master to main October 14, 2024 04:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants