diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 494367b813bf3e..e85d8e304d09bd 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -807,7 +807,7 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, devi
         return extended_attention_mask
 
     def get_extended_attention_mask(
-        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
+        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
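
For context, a minimal sketch of how the corrected annotation lines up with a typical call site. This is not part of the patch: the `BertModel(BertConfig())` instantiation and the tensor values below are illustrative only, chosen so the snippet runs without downloading a checkpoint.

```python
import torch
from transformers import BertConfig, BertModel

# Randomly initialized model, no download needed (illustrative, not from the patch).
model = BertModel(BertConfig())
attention_mask = torch.tensor([[1, 1, 1, 0]])

# With this change, `device` is annotated as `torch.device`, matching what callers pass in.
extended_mask = model.get_extended_attention_mask(
    attention_mask, input_shape=attention_mask.shape, device=torch.device("cpu")
)
print(extended_mask.shape)  # broadcastable mask, e.g. torch.Size([1, 1, 1, 4])
```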