Skip to content

Commit

Permalink
Remove device parameter from create_extended_attention_mask_for_decod…
Browse files Browse the repository at this point in the history
  • Loading branch information
pbelevich authored May 3, 2022
1 parent dd739f7 commit 39f8eaf
Show file tree
Hide file tree
Showing 31 changed files with 48 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/research_projects/longform-qa/eli5_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
attention_mask, input_shape
)

# define function for checkpointing
Expand Down
20 changes: 16 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,13 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
return encoder_extended_attention_mask

@staticmethod
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
else:
device = attention_mask.device
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
Expand All @@ -672,7 +678,9 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, devi
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask

def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Expand All @@ -681,12 +689,16 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
device: (`torch.device`):
The device of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
if not (attention_mask.dim() == 2 and self.config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ def forward(
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = None
if not use_cache:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
4 changes: 1 addition & 3 deletions src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,9 +2112,7 @@ def forward(
to_mask = None
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
else:
raise ValueError(
f"attention_type can either be original_full or block_sparse, but is {self.attention_type}"
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/canine/modeling_canine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,12 +1130,12 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
molecule_attention_mask = self._downsample_attention_mask(
attention_mask, downsampling_rate=self.config.downsampling_rate
)
extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
)

# Prepare head mask if needed
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/convbert/modeling_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

hidden_states = self.embeddings(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/data2vec/modeling_data2vec_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/ibert/modeling_ibert.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/longformer/modeling_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
:, 0, 0, :
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mmbt/modeling_mmbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def forward(
[torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
)

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

Expand Down
4 changes: 1 addition & 3 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,9 +875,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, self.device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mpnet/modeling_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def forward(

if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qdqbert/modeling_qdqbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/realm/modeling_realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/rembert/modeling_rembert.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/retribert/modeling_retribert.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def embed_sentences_checkpointed(
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
attention_mask, input_shape
)

# define function for checkpointing
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/roformer/modeling_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/splinter/modeling_splinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def forward(
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/tapas/modeling_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/vilt/modeling_vilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

encoder_outputs = self.encoder(
embedding_output,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/visual_bert/modeling_visual_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,12 +794,12 @@ def forward(
if visual_embeds is not None:
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
combined_attention_mask, (batch_size, input_shape + visual_input_shape)
)

else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, [batch_size, input_shape], device
attention_mask, (batch_size, input_shape)
)

# Prepare head mask if needed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/yoso/modeling_yoso.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
Loading

0 comments on commit 39f8eaf

Please sign in to comment.