From 344b9fb0c6b0cdec70f918fc59e862365815eb19 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 25 Apr 2022 20:58:50 -0400
Subject: [PATCH] Limit the use of PreTrainedModel.device (#16935)

* Limit the use of PreTrainedModel.device

* Fix
---
 src/transformers/generation_utils.py | 16 ++++++++++------
 src/transformers/modeling_utils.py   |  4 ++--
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 1bdcd06f0d5c..76086c4b7d63 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -502,7 +502,7 @@ def _prepare_attention_mask_for_generation(
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
         else:
-            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
 
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
@@ -532,13 +532,16 @@ def _prepare_decoder_input_ids_for_generation(
         decoder_start_token_id: int = None,
         bos_token_id: int = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        device: torch.device = None,
     ) -> torch.LongTensor:
 
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
             decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
+            if device is None:
+                device = self.device
+            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
 
     def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
         decoder_start_token_id = (
@@ -1177,6 +1180,7 @@ def generate(
                 decoder_start_token_id=decoder_start_token_id,
                 bos_token_id=bos_token_id,
                 model_kwargs=model_kwargs,
+                device=inputs_tensor.device,
             )
         else:
             # if decoder-only then inputs_tensor has to be `input_ids`
@@ -1327,7 +1331,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1367,7 +1371,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * num_return_sequences,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
             )
@@ -1410,7 +1414,7 @@ def generate(
                 batch_size=batch_size,
                 num_beams=num_beams,
                 max_length=stopping_criteria.max_length,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
@@ -1492,7 +1496,7 @@ def typeerror():
                 constraints=final_constraints,
                 batch_size=batch_size,
                 num_beams=num_beams,
-                device=self.device,
+                device=inputs_tensor.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f23623e5a996..17968e72aa8e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1157,7 +1157,7 @@ def _get_resized_embeddings(
 
         # Build new embeddings
         new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
+        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -1228,7 +1228,7 @@ def _get_resized_lm_head(
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
         new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
+        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
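
The recurring change in this patch is to derive the device for newly created tensors (attention masks, decoder prompts, beam-search scorers, resized embeddings) from a tensor the code already holds, rather than from `PreTrainedModel.device`, which reflects only where the model's parameters happen to sit and can disagree with the tensors being processed when weights are split across devices. Below is a minimal, self-contained sketch of that pattern, assuming only `torch`; the helper name `default_attention_mask` is illustrative and not a transformers API.

import torch


def default_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
    # Mirror the `device=inputs.device` change in
    # `_prepare_attention_mask_for_generation`: allocate the mask on whatever
    # device the inputs already occupy (CPU, a specific GPU, ...), instead of
    # on a model-level `self.device`.
    return torch.ones(input_ids.shape[:2], dtype=torch.long, device=input_ids.device)


input_ids = torch.tensor([[101, 7592, 2088, 102]])  # lives on CPU here
mask = default_attention_mask(input_ids)
assert mask.device == input_ids.device  # holds regardless of where input_ids lives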