AL-160: support fast tokenizer #2

Merged
3 commits merged on Aug 27, 2020
45 changes: 40 additions & 5 deletions src/transformers/data/processors/squad.py
@@ -8,7 +8,8 @@

 from ...file_utils import is_tf_available, is_torch_available
 from ...tokenization_bert import whitespace_tokenize
-from ...tokenization_utils_base import TruncationStrategy
+from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 from .utils import DataProcessor

@@ -107,6 +108,14 @@ def squad_convert_example_to_features(
     tok_to_orig_index = []
     orig_to_tok_index = []
     all_doc_tokens = []
+    if isinstance(tokenizer, PreTrainedTokenizerFast):
+        tokenizer.set_truncation_and_padding(
+            padding_strategy=PaddingStrategy.DO_NOT_PAD,
+            truncation_strategy=TruncationStrategy.LONGEST_FIRST,
+            max_length=64,
+            stride=0,
+            pad_to_multiple_of=None,
+        )
     for (i, token) in enumerate(example.doc_tokens):
         orig_to_tok_index.append(len(all_doc_tokens))
         sub_tokens = tokenizer.tokenize(token)
@@ -131,6 +140,12 @@
         example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
     )

+    # Handle the case where the tokenized query is empty, since the fast tokenizer does not raise an error for it
+    if len(truncated_query) == 0:
+        raise ValueError(
+            f"Input {truncated_query} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
+        )
+
     # Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
     # in the way they compute mask of added tokens.
     tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
@@ -146,12 +161,28 @@

         # Define the side we want to truncate / pad and the text/pair sorting
         if tokenizer.padding_side == "right":
-            texts = truncated_query
-            pairs = span_doc_tokens
+            texts = (
+                truncated_query
+                if not isinstance(tokenizer, PreTrainedTokenizerFast)
+                else tokenizer.decode(truncated_query)
+            )
+            # Needed because some tokenizers seem to produce actual tokens,
+            # while others produce token_ids for overflow tokens
+            if isinstance(span_doc_tokens[0], str):
+                pairs = " ".join(span_doc_tokens).replace(" ##", "").strip()
+            else:
+                pairs = span_doc_tokens
             truncation = TruncationStrategy.ONLY_SECOND.value
         else:
-            texts = span_doc_tokens
-            pairs = truncated_query
+            if isinstance(span_doc_tokens[0], str):
+                texts = " ".join(span_doc_tokens).replace(" ##", "").strip()
+            else:
+                texts = span_doc_tokens
+            pairs = (
+                truncated_query
+                if not isinstance(tokenizer, PreTrainedTokenizerFast)
+                else tokenizer.decode(truncated_query)
+            )
             truncation = TruncationStrategy.ONLY_FIRST.value

         encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
@@ -165,6 +196,10 @@
             return_token_type_ids=True,
         )

+        # Handle case where fast tokenizer returns list[list[int]]
+        if isinstance(encoded_dict["input_ids"][0], list):
+            encoded_dict = {k: v[0] for k, v in encoded_dict.items()}
+
         paragraph_len = min(
             len(all_doc_tokens) - len(spans) * doc_stride,
             max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
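The branches above exist because a fast tokenizer's encode() returns token ids (so the truncated query has to be decoded back into a string before the query/span pair is re-encoded), while the span side can arrive either as WordPiece strings or as ids, as the inline comment notes. The snippet below is a minimal standalone sketch of those two conversions, not part of the diff; the bert-base-uncased checkpoint and the printed shapes are assumptions, and exact outputs can vary across transformers versions.

from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained("bert-base-uncased")
fast = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Both tokenizers return token ids from encode(); squad.py decodes the truncated
# query back into a string before handing it to encode_plus for a fast tokenizer.
query_ids = fast.encode("Who wrote the paper?", add_special_tokens=False, truncation=True, max_length=64)
query_text = fast.decode(query_ids)
print(query_ids)   # a flat list of ints
print(query_text)  # roughly "who wrote the paper?"

# The span side can be a list of WordPiece strings from tokenize(); joining them and
# removing the " ##" continuation markers rebuilds plain text that encode_plus accepts.
span_doc_tokens = slow.tokenize("Transformers provides thousands of pretrained models")
span_text = " ".join(span_doc_tokens).replace(" ##", "").strip()
print(span_doc_tokens)  # WordPiece strings, possibly containing "##" continuations
print(span_text)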
4 changes: 4 additions & 0 deletions tests/test_pipelines.py
@@ -672,12 +672,16 @@ def test_torch_question_answering(self):
         for model_name in QA_FINETUNED_MODELS:
             nlp = pipeline(task="question-answering", model=model_name, tokenizer=model_name)
             self._test_qa_pipeline(nlp)
+            nlp = pipeline(task="question-answering", model=model_name, tokenizer=(model_name, {"use_fast": True}))
+            self._test_qa_pipeline(nlp)

     @require_tf
     def test_tf_question_answering(self):
         for model_name in QA_FINETUNED_MODELS:
             nlp = pipeline(task="question-answering", model=model_name, tokenizer=model_name, framework="tf")
             self._test_qa_pipeline(nlp)
+            nlp = pipeline(task="question-answering", model=model_name, tokenizer=(model_name, {"use_fast": True}), framework="tf")
+            self._test_qa_pipeline(nlp)


 class NerPipelineTests(unittest.TestCase):
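For reference, this is roughly what the new test lines exercise outside the test harness: a question-answering pipeline built with a fast tokenizer via the (name, {"use_fast": True}) tokenizer spec. The checkpoint name below is an illustrative assumption, not necessarily one of the QA_FINETUNED_MODELS entries used by the tests.

from transformers import pipeline

model_name = "distilbert-base-cased-distilled-squad"  # assumed checkpoint; any QA-finetuned model works
nlp = pipeline(
    task="question-answering",
    model=model_name,
    tokenizer=(model_name, {"use_fast": True}),  # same tokenizer spec the new tests pass
)

result = nlp(
    question="What does this pull request add?",
    context="This pull request adds fast tokenizer support to the SQuAD feature conversion code.",
)
print(result)  # a dict with "score", "start", "end" and "answer" keys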