Use model.from_pretrained for DataParallel also (huggingface#8795)
* Use model.from_pretrained for DataParallel also

When training on multiple GPUs, the code wraps the model with `torch.nn.DataParallel`. However, if the model has custom `from_pretrained` logic, it does not get applied during `load_best_model_at_end`.

This commit uses the underlying model during `load_best_model_at_end` and re-wraps the loaded model with `DataParallel`.
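Below is a minimal sketch of that unwrap/reload/re-wrap pattern, assuming the underlying model exposes `from_pretrained` (as `PreTrainedModel` subclasses do). The helper name is hypothetical, not an actual `Trainer` API, and the final version of this commit takes the simpler route shown in the diff below.

```python
import torch.nn as nn

def reload_best_checkpoint(model, checkpoint_dir):
    # Hypothetical helper illustrating the idea; not part of Trainer.
    wrapped = isinstance(model, nn.DataParallel)
    # Unwrap first so any custom from_pretrained logic on the underlying
    # model class actually runs when loading the best checkpoint.
    underlying = model.module if wrapped else model
    reloaded = underlying.from_pretrained(checkpoint_dir)
    # Re-wrap so multi-GPU training can continue on the reloaded weights.
    return nn.DataParallel(reloaded) if wrapped else reloaded
```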

If you choose to reject this change, then could you please move this logic to a function, e.g. `def load_best_model_checkpoint(best_model_checkpoint)` or something, so that it can be overridden?

* Fix silly bug

* Address review comments

Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check whether `self.model` is an instance of `PreTrainedModel`; otherwise we would still not get into that `if` branch, right?
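For context, a short illustration of why the check has to look at `self.model` (which `Trainer` keeps unwrapped) rather than the local `model` variable (which may be the `DataParallel` wrapper); the checkpoint name is an arbitrary example:

```python
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
wrapped = nn.DataParallel(model)  # what the local `model` becomes under multi-GPU training

print(isinstance(wrapped, PreTrainedModel))         # False: the wrapper hides the model class
print(isinstance(wrapped.module, PreTrainedModel))  # True: the underlying model passes the check
```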
shaie authored Nov 30, 2020
1 parent 4062c75 commit 7738494
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/trainer.py: 4 changes (2 additions & 2 deletions)
```diff
@@ -808,8 +808,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            if isinstance(model, PreTrainedModel):
-                self.model = model.from_pretrained(self.state.best_model_checkpoint)
+            if isinstance(self.model, PreTrainedModel):
+                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
                 if not self.args.model_parallel:
                     self.model = self.model.to(self.args.device)
             else:
```
