Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding missing type hints for BigBird model #16555

Merged
merged 16 commits into from
Apr 5, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding missing type hints for mBART model
Tensorflow Implementation model added with missing type hints
  • Loading branch information
reichenbch authored and Rocketknight1 committed Mar 21, 2022
commit ef9ef212d2ff56b80bdaa2c47a2316e89d0fbfe7
147 changes: 74 additions & 73 deletions src/transformers/models/mbart/modeling_tf_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
Expand Down Expand Up @@ -299,7 +300,7 @@ def __init__(self, config: MBartConfig, **kwargs):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")

def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: Optional[bool] = False):
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
Expand Down Expand Up @@ -374,7 +375,7 @@ def call(
layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
training: Optional[bool] = False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
"""
Args:
Expand Down Expand Up @@ -669,16 +670,16 @@ def set_embed_tokens(self, embed_tokens):
@unpack_inputs
def call(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> TFBaseModelOutput:
"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Expand Down Expand Up @@ -828,21 +829,21 @@ def set_embed_tokens(self, embed_tokens):
@unpack_inputs
def call(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: TFModelInputType = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, tuple(tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor)]:
r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Expand Down Expand Up @@ -1049,24 +1050,24 @@ def set_input_embeddings(self, new_embeddings):
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
input_ids: TFModelInputType = None,
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs
):
) -> Union[TFSeq2SeqModelOutput, tf.Tensor]:

if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False
Expand Down Expand Up @@ -1157,24 +1158,24 @@ def get_decoder(self):
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
input_ids: tf.Tensor = None,
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs
):
) -> Union[TFSeq2SeqModelOutput, Tuple[tf.tensor]] :

outputs = self.model(
input_ids=input_ids,
Expand Down Expand Up @@ -1261,25 +1262,25 @@ def set_bias(self, value):
@add_end_docstrings(MBART_GENERATION_EXAMPLE)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
input_ids: TFModelInputType = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
past_key_values: [Tuple[Tuple[tf.Tensor]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]] :
"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
Expand Down