[S2T, Whisper] Add copied from statements (huggingface#20787)
* [S2T, Whisper] Add copied from statements

* rebase and fix-copies
sanchit-gandhi authored Dec 20, 2022
1 parent 5eecf3f commit bd1a43b
Showing 2 changed files with 12 additions and 10 deletions.
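The `# Copied from` comments added below are machine-checked annotations, not plain documentation: the repository's copy-consistency tooling (run as `make fix-copies`, the step named in the commit message) resolves the referenced source object, applies the `Old->New` rename patterns from the comment, and overwrites the annotated target if it has drifted. A minimal sketch of that idea, assuming a hypothetical `apply_copy` helper — the real logic lives in `utils/check_copies.py` and additionally handles indentation, multiple patterns, and casing variants:

import re

# Minimal sketch of the "# Copied from" mechanism; hypothetical helper,
# not the real utils/check_copies.py implementation.
COPIED_FROM = re.compile(r"# Copied from (\S+)(?: with (.+))?")

def apply_copy(comment: str, source_code: str) -> str:
    """Render the code a '# Copied from' target is expected to contain."""
    match = COPIED_FROM.match(comment)
    if match is None:
        raise ValueError("not a '# Copied from' comment")
    expected = source_code
    if match.group(2):  # rename patterns, e.g. "MBart->Speech2Text"
        for pattern in match.group(2).split(","):
            old, new = (part.strip() for part in pattern.split("->"))
            expected = expected.replace(old, new)
    return expected

# The comment added in this commit rewrites every MBart identifier:
comment = (
    "# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer"
    " with MBart->Speech2Text"
)
source = "class MBartEncoderLayer(nn.Module):\n    def __init__(self, config: MBartConfig):"
print(apply_copy(comment, source))
# class Speech2TextEncoderLayer(nn.Module):
#     def __init__(self, config: Speech2TextConfig):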
10 changes: 6 additions & 4 deletions src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -354,6 +354,7 @@ def forward(
         return attn_output, attn_weights_reshaped, past_key_value
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text
 class Speech2TextEncoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -377,14 +378,14 @@ def forward(
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -422,6 +423,7 @@ def forward(
         return outputs
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text
 class Speech2TextDecoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -460,7 +462,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -473,7 +475,7 @@
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
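Note that the docstring and signature tweaks in this file — `(config.encoder_attention_heads,)` becoming `(encoder_attention_heads,)`, *(decoder_attention_heads,)* becoming `(decoder_attention_heads,)`, and the added `-> torch.Tensor` annotations — are presumably not hand-written: once the layers carry the new annotations, fix-copies rewrites their bodies to match the MBart source exactly, modulo the declared renames.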
12 changes: 6 additions & 6 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -261,7 +261,7 @@ def forward(
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextEncoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
 class WhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -285,14 +285,14 @@ def forward(
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -330,7 +330,7 @@ def forward(
         return outputs
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextDecoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper
 class WhisperDecoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -369,7 +369,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -382,7 +382,7 @@
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
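The Whisper hunks repoint the existing annotations from Speech2Text to MBart: since Speech2TextEncoderLayer and Speech2TextDecoderLayer are themselves now declared as copies of the MBart layers, referencing MBart directly keeps every copy one hop from its canonical source instead of chaining through an intermediate model. Re-running `make fix-copies` after such a change re-syncs all annotated targets, which is presumably the "rebase and fix-copies" step in the commit message.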
