Add type hints for ProphetNet (Pytorch) (huggingface#17223)
* added type hints to prophetnet

* reformatted with black

* fix parts that black misformatted

* fix imports

* fix imports

* Update src/transformers/models/prophetnet/configuration_prophetnet.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* update OPTIONAL type hint and docstring

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
jQuinRivero and Rocketknight1 authored May 18, 2022
1 parent d6b8e9c commit 7ba1d4e
Showing 3 changed files with 52 additions and 48 deletions.
53 changes: 27 additions & 26 deletions src/transformers/models/prophetnet/configuration_prophetnet.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" ProphetNet model configuration"""

from typing import Callable, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -105,32 +106,32 @@ class ProphetNetConfig(PretrainedConfig):

def __init__(
self,
activation_dropout=0.1,
activation_function="gelu",
vocab_size=30522,
hidden_size=1024,
encoder_ffn_dim=4096,
num_encoder_layers=12,
num_encoder_attention_heads=16,
decoder_ffn_dim=4096,
num_decoder_layers=12,
num_decoder_attention_heads=16,
attention_dropout=0.1,
dropout=0.1,
max_position_embeddings=512,
init_std=0.02,
is_encoder_decoder=True,
add_cross_attention=True,
decoder_start_token_id=0,
ngram=2,
num_buckets=32,
relative_max_distance=128,
disable_ngram_loss=False,
eps=0.0,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
activation_dropout: Optional[float] = 0.1,
activation_function: Optional[Union[str, Callable]] = "gelu",
vocab_size: Optional[int] = 30522,
hidden_size: Optional[int] = 1024,
encoder_ffn_dim: Optional[int] = 4096,
num_encoder_layers: Optional[int] = 12,
num_encoder_attention_heads: Optional[int] = 16,
decoder_ffn_dim: Optional[int] = 4096,
num_decoder_layers: Optional[int] = 12,
num_decoder_attention_heads: Optional[int] = 16,
attention_dropout: Optional[float] = 0.1,
dropout: Optional[float] = 0.1,
max_position_embeddings: Optional[int] = 512,
init_std: Optional[float] = 0.02,
is_encoder_decoder: Optional[bool] = True,
add_cross_attention: Optional[bool] = True,
decoder_start_token_id: Optional[int] = 0,
ngram: Optional[int] = 2,
num_buckets: Optional[int] = 32,
relative_max_distance: Optional[int] = 128,
disable_ngram_loss: Optional[bool] = False,
eps: Optional[float] = 0.0,
use_cache: Optional[bool] = True,
pad_token_id: Optional[int] = 0,
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
**kwargs
):
self.vocab_size = vocab_size
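To put the new annotations in context, here is a minimal usage sketch (not part of the commit) of the annotated constructor; the `Optional[...]` hints simply mark keyword arguments that carry defaults, and `activation_function` may be given as a string name or a callable:

from transformers import ProphetNetConfig

# Hypothetical example, not part of the diff: every keyword below now carries a type hint.
config = ProphetNetConfig(
    vocab_size=30522,            # int
    hidden_size=1024,            # int
    activation_function="gelu",  # str or Callable
    dropout=0.1,                 # float
    use_cache=True,              # bool
)
print(config.num_encoder_layers)  # 12, the annotated default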
12 changes: 6 additions & 6 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -345,7 +345,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`):
last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
@@ -590,7 +590,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
the forward function.
"""

def __init__(self, config: ProphetNetConfig):
def __init__(self, config: ProphetNetConfig) -> None:
self.max_length = config.max_position_embeddings
super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)

@@ -1407,7 +1407,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
embeddings instead of randomly initialized word embeddings.
"""

def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
super().__init__(config)

self.ngram = config.ngram
@@ -1769,7 +1769,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetModel(ProphetNetPreTrainedModel):
def __init__(self, config):
def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

@@ -2106,7 +2106,7 @@ def get_decoder(self):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def __init__(self, config):
def __init__(self, config: ProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
config.is_decoder = True
@@ -2341,7 +2341,7 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
classes.
"""

def __init__(self, config):
def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.decoder = ProphetNetDecoder(config)

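As an illustration of the annotated constructors above (`config: ProphetNetConfig` on the model classes and `Optional[nn.Embedding]` on `ProphetNetDecoder`), a hedged sketch follows; the tiny configuration values are chosen only to keep the example light and are not checkpoint settings:

import torch
import torch.nn as nn
from transformers import ProphetNetConfig
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoder, ProphetNetModel

# Small, illustrative configuration (assumed values, not defaults).
config = ProphetNetConfig(
    vocab_size=128,
    hidden_size=64,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_encoder_attention_heads=4,
    num_decoder_attention_heads=4,
    encoder_ffn_dim=128,
    decoder_ffn_dim=128,
)

# word_embeddings is Optional: pass a shared nn.Embedding, or None to let the
# decoder build its own embedding table.
shared = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
decoder = ProphetNetDecoder(config, word_embeddings=shared)
print(decoder.word_embeddings is shared)  # True

# The full model takes only the typed config; its output exposes the *optional*
# ngram prediction stream documented above, alongside the main stream.
model = ProphetNetModel(config)
outputs = model(
    input_ids=torch.tensor([[1, 5, 7, 2]]),
    decoder_input_ids=torch.tensor([[0, 5, 7]]),
)
print(outputs.last_hidden_state.shape)        # main decoder stream
print(outputs.last_hidden_state_ngram.shape)  # ngram prediction stream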
35 changes: 19 additions & 16 deletions src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -15,7 +15,7 @@

import collections
import os
from typing import List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -111,17 +111,17 @@ class ProphetNetTokenizer(PreTrainedTokenizer):

def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
x_sep_token="[X_SEP]",
pad_token="[PAD]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
vocab_file: str,
do_lower_case: Optional[bool] = True,
do_basic_tokenize: Optional[bool] = True,
never_split: Optional[Iterable] = None,
unk_token: Optional[str] = "[UNK]",
sep_token: Optional[str] = "[SEP]",
x_sep_token: Optional[str] = "[X_SEP]",
pad_token: Optional[str] = "[PAD]",
mask_token: Optional[str] = "[MASK]",
tokenize_chinese_chars: Optional[bool] = True,
strip_accents: Optional[bool] = None,
**kwargs
):
super().__init__(
@@ -177,21 +177,24 @@ def _tokenize(self, text):
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens

def _convert_token_to_id(self, token):
def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))

def _convert_id_to_token(self, index):
def _convert_id_to_token(self, index: int):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)

def convert_tokens_to_string(self, tokens):
def convert_tokens_to_string(self, tokens: str):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string

def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: Optional[bool] = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
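Lastly, a brief sketch of the tokenizer whose arguments were annotated above; it assumes the public `microsoft/prophetnet-large-uncased` checkpoint, though any vocab file path (the now-typed `vocab_file: str`) would work as well:

from transformers import ProphetNetTokenizer

# Illustrative only: flags such as do_lower_case are booleans; strip_accents may
# stay None so that it typically follows the lower-casing setting.
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")

ids = tokenizer("ProphetNet predicts future n-grams.").input_ids
# already_has_special_tokens=True because the ids above already end with [SEP].
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(list(zip(tokenizer.convert_ids_to_tokens(ids), mask)))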
