Limit the use of PreTrainedModel.device (huggingface#16935)
* Limit the use of PreTrainedModel.device

* Fix
sgugger authored Apr 26, 2022
1 parent 6568752 commit 344b9fb
Showing 2 changed files with 12 additions and 8 deletions.
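The two files below follow one pattern: tensors created during generation or embedding resizing are allocated on the device of the tensors they accompany, instead of on PreTrainedModel.device, which only reports the device of one of the model's parameters and can disagree with the inputs when a model is spread across several devices. A minimal standalone sketch of the pattern (TinyModel and make_mask are illustrative names, not transformers code):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Illustrative module; not part of transformers."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)

    @property
    def device(self) -> torch.device:
        # Mirrors the spirit of PreTrainedModel.device: the device of a parameter,
        # which may not match the device of the tensors passed in.
        return next(self.parameters()).device

    def make_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        # The pattern adopted by this commit: derive the device from the inputs
        # so the new tensor always lives alongside them.
        return torch.ones(input_ids.shape[:2], dtype=torch.long, device=input_ids.device)

model = TinyModel()
input_ids = torch.zeros((2, 5), dtype=torch.long)  # could live on any device
mask = model.make_mask(input_ids)
assert mask.device == input_ids.device  # holds regardless of where model.embed sits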
16 changes: 10 additions & 6 deletions src/transformers/generation_utils.py
@@ -502,7 +502,7 @@ def _prepare_attention_mask_for_generation(
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
         else:
-            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
 
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
@@ -532,13 +532,16 @@ def _prepare_decoder_input_ids_for_generation(
         decoder_start_token_id: int = None,
         bos_token_id: int = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        device: torch.device = None,
     ) -> torch.LongTensor:
 
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
+            if device is None:
+                device = self.device
+            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
 
     def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
         decoder_start_token_id = (
@@ -1177,6 +1180,7 @@ def generate(
                 decoder_start_token_id=decoder_start_token_id,
                 bos_token_id=bos_token_id,
                 model_kwargs=model_kwargs,
+                device=inputs_tensor.device,
             )
         else:
             # if decoder-only then inputs_tensor has to be `input_ids`
@@ -1327,7 +1331,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1367,7 +1371,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * num_return_sequences,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
             )
@@ -1410,7 +1414,7 @@
                 batch_size=batch_size,
                 num_beams=num_beams,
                 max_length=stopping_criteria.max_length,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1492,7 +1496,7 @@ def typeerror():
                 constraints=final_constraints,
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
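For encoder-decoder generation, generate() now forwards inputs_tensor.device to _prepare_decoder_input_ids_for_generation, which falls back to self.device only when no device is passed. A standalone sketch of that flow (prepare_decoder_start and the start-token id 2 are illustrative, not the actual transformers helper):

import torch

def prepare_decoder_start(batch_size: int, decoder_start_token_id: int, device: torch.device) -> torch.LongTensor:
    # The caller supplies the device of the encoder inputs, as generate() now does.
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

encoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 0]])  # wherever the user placed them
decoder_input_ids = prepare_decoder_start(
    batch_size=encoder_input_ids.shape[0],
    decoder_start_token_id=2,              # illustrative id
    device=encoder_input_ids.device,       # follows the inputs, as in the hunks above
)
print(decoder_input_ids)                   # tensor([[2], [2]]) on the inputs' device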
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Expand Up @@ -1157,7 +1157,7 @@ def _get_resized_embeddings(

# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
@@ -1228,7 +1228,7 @@ def _get_resized_lm_head(
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
         new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
+        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
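In modeling_utils.py the same idea is applied to embedding resizing: the freshly built embedding or LM head is moved to the device and dtype of the old weight it replaces rather than to the model-wide device. A standalone sketch of that step (layer sizes here are arbitrary):

import torch
import torch.nn as nn

old_embeddings = nn.Embedding(100, 16)  # imagine this layer lives on some accelerator
new_embeddings = nn.Embedding(120, 16)

# Move the new layer to where the old weight actually is, mirroring the diff above.
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# Keep the existing rows, as resize_token_embeddings does for tokens that are retained.
with torch.no_grad():
    new_embeddings.weight[: old_embeddings.num_embeddings] = old_embeddings.weight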
