[trainer] refactor place_model_on_device logic, add deepspeed (huggingface#10243)

* refactor place_model_on_device logic, add deepspeed

* doc

* style
stas00 authored Feb 17, 2021
1 parent d1eb88f commit dee876c
Showing 1 changed file with 12 additions and 3 deletions.
src/transformers/trainer.py: 15 changes (12 additions, 3 deletions)
@@ -214,6 +214,10 @@ class Trainer:
           inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
         - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
           data parallelism, this means some of the model layers are split on different GPUs).
+        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+          to :obj:`False` if model parallel or deepspeed is used, or if the default
+          ``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False`.
     """

     def __init__(
@@ -262,6 +266,11 @@ def __init__(
         else:
             self.is_model_parallel = False

+        # one place to sort out whether to place the model on device or not
+        self.place_model_on_device = args.place_model_on_device
+        if self.is_model_parallel or (args.deepspeed and args.do_train):
+            self.place_model_on_device = False
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
@@ -272,7 +281,7 @@ def __init__(
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         #    and we only use deepspeed for training at the moment
-        if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
+        if self.place_model_on_device:
             model = model.to(args.device)

         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
@@ -780,7 +789,7 @@ def train(

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
-            if not self.is_model_parallel and self.args.place_model_on_device:
+            if self.place_model_on_device:
                 self.model = self.model.to(self.args.device)
                 self.model_wrapped = self.model

@@ -1033,7 +1042,7 @@ def train(
             )
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if not self.is_model_parallel and self.args.place_model_on_device:
+                if self.place_model_on_device:
                     self.model = self.model.to(self.args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
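The refactor boils the three inline checks down to a single flag computed once in ``__init__``. As a reading aid, here is a minimal standalone sketch of that decision, not library code: ``should_place_model_on_device`` is a hypothetical helper, and its boolean arguments stand in for ``Trainer.is_model_parallel`` and the ``TrainingArguments`` fields used in the diff.

def should_place_model_on_device(default_flag: bool,
                                 is_model_parallel: bool,
                                 deepspeed_enabled: bool,
                                 do_train: bool) -> bool:
    # Sketch of the flag computed once in Trainer.__init__ and stored as
    # self.place_model_on_device (hypothetical free-standing helper).
    # Skip the blanket `.to(device)` when:
    #   1. model parallelism - layers are already spread over several GPUs
    #   2. DeepSpeed training - DeepSpeed loads fp16 weights and manages
    #      placement itself, so `.to()` is unnecessary
    if is_model_parallel or (deepspeed_enabled and do_train):
        return False
    return default_flag


# The three former call sites then collapse to a single check, e.g.:
#     if self.place_model_on_device:
#         model = model.to(args.device)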

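The new docstring also notes that the flag follows ``TrainingArguments.place_model_on_device`` unless model parallelism or DeepSpeed training forces it off. Assuming that attribute is exposed as an overridable property (as the wording "overridden to return :obj:`False`" suggests), opting out of automatic placement could look roughly like the sketch below; ``NoAutoPlacementArguments`` is a hypothetical name used only for illustration.

from transformers import TrainingArguments


class NoAutoPlacementArguments(TrainingArguments):
    # Hypothetical subclass: override the default so Trainer leaves device
    # placement to the caller (it will skip `model.to(args.device)`).
    @property
    def place_model_on_device(self):
        return False


args = NoAutoPlacementArguments(output_dir="output")
# A Trainer built with `args` then keeps the model wherever it already is,
# mirroring what happens automatically for model parallelism or DeepSpeed.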