
AL-160: added code and tests for batching #3

Merged: 1 commit, Aug 27, 2020

4 changes: 2 additions & 2 deletions src/transformers/data/processors/squad.py
@@ -388,7 +388,7 @@ def squad_convert_examples_to_features(
)
features = list(
tqdm(
p.imap(annotate_, examples, chunksize=32),
p.map(annotate_, examples),
total=len(examples),
desc="convert squad examples to features",
disable=not tqdm_enabled,
@@ -408,7 +408,7 @@ def squad_convert_examples_to_features(
new_features.append(example_feature)
unique_id += 1
example_index += 1
features = new_features
features = new_features if new_features else [[]]
del new_features
if return_dataset == "pt":
if not is_torch_available():
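Two small changes here: the lazy `Pool.imap(..., chunksize=32)` call becomes a blocking `Pool.map(...)`, and `features` falls back to `[[]]` when no features were produced. A minimal standalone sketch (not part of this diff) of how the two pool calls differ:

```python
from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == "__main__":
    with Pool(2) as pool:
        # imap returns an iterator; results stream back chunk by chunk, which is
        # what lets a tqdm wrapper advance while workers are still running.
        streamed = list(pool.imap(square, range(8), chunksize=4))
        # map blocks until every result is ready and hands back a plain list,
        # so a progress bar around it only jumps at the very end.
        blocking = pool.map(square, range(8))
    assert streamed == blocking == [x * x for x in range(8)]
```
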
172 changes: 107 additions & 65 deletions src/transformers/pipelines.py
@@ -21,6 +21,7 @@
import sys
import uuid
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from itertools import chain
from multiprocessing import cpu_count
@@ -29,6 +30,7 @@
from uuid import UUID

import numpy as np
from tqdm import tqdm

from .configuration_auto import AutoConfig
from .configuration_utils import PretrainedConfig
@@ -38,7 +40,7 @@
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
from .tokenization_utils_base import BatchEncoding
from .utils import logging


@@ -1383,7 +1385,7 @@ def __call__(self, *args, **kwargs):
if self.framework == "tf":
entities = self.model(tokens.data)[0][0].numpy()
input_ids = tokens["input_ids"].numpy()[0]
else:
elif self.framework == "pt":
with torch.no_grad():
tokens = self.ensure_tensor_on_device(**tokens)
entities = self.model(**tokens)[0][0].cpu().numpy()
@@ -1401,7 +1403,6 @@ def __call__(self, *args, **kwargs):
]

for idx, label_idx in filtered_labels_idx:

entity = {
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
"score": score[idx][label_idx].item(),
@@ -1637,6 +1638,57 @@ def create_sample(
else:
return SquadExample(None, question, context, None, None, None)

def extract_answers(self, example, features_and_positions, handle_impossible_answer, topk, max_answer_len):
features, start, end = [], [], []

for feature, start_pos, end_pos in features_and_positions:
features.append(feature)
start.append(start_pos)
end.append(end_pos)

min_null_score = 1000000 # large and positive
answers = []
for (feature, start_, end_) in zip(features, start, end):
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

# Generate mask
undesired_tokens_mask = undesired_tokens == 0.0

# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start_)
end_ = np.where(undesired_tokens_mask, -10000.0, end_)

# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

if handle_impossible_answer:
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

# Mask CLS
start_[0] = end_[0] = 0.0

starts, ends, scores = self.decode(start_, end_, topk, max_answer_len)
char_to_word = np.array(example.char_to_word_offset)

# Convert the answer (tokens) back to the original text
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]

if handle_impossible_answer:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
return sorted(answers, key=lambda x: x["score"], reverse=True)[:topk]

def __call__(self, *args, **kwargs):
"""
Answer the question(s) given as inputs by using the context(s).
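
The probability computation in `extract_answers` above is a log-space softmax applied after every position that cannot hold an answer has been pushed to a large negative logit. A small self-contained numpy sketch of that step (toy numbers and an assumed token layout; not taken from the pipeline):

```python
import numpy as np

# Toy logits for a 6-token feature laid out as [CLS] q q [SEP] c c
# (all numbers are made up; this is not model output).
start_logits = np.array([0.1, 2.0, 1.5, 0.0, 3.0, 2.5])
p_mask = np.array([0, 1, 1, 1, 0, 0])           # 1 = token may not be part of an answer
attention_mask = np.array([1, 1, 1, 1, 1, 1])   # 1 = real (non-padding) token

# Same masking as above: only real, non-question tokens stay candidates.
allowed = np.abs(p_mask - 1) & attention_mask
start_logits = np.where(allowed == 0, -10000.0, start_logits)

# exp(x - log(sum(exp(x)))) is a softmax written in log space for numerical stability.
start_probs = np.exp(start_logits - np.log(np.sum(np.exp(start_logits), axis=-1, keepdims=True)))
print(start_probs.round(3))  # masked positions end up with probability ~0
```
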
@@ -1686,6 +1738,8 @@ def __call__(self, *args, **kwargs):
kwargs.setdefault("max_seq_len", 384)
kwargs.setdefault("max_question_len", 64)
kwargs.setdefault("handle_impossible_answer", False)
kwargs.setdefault("batch_size", 16)
kwargs.setdefault("enable_tqdm", True)

if kwargs["topk"] < 1:
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
@@ -1695,25 +1749,29 @@

# Convert inputs to features
examples = self._args_parser(*args, **kwargs)
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
all_answers = []
for features, example in zip(features_list, examples):
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

# Manage tensor allocation on correct device
features_list = squad_convert_examples_to_features(
examples=examples,
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
# padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
is_training=False,
tqdm_enabled=kwargs["enable_tqdm"],
)

flattend_examples = [examples[feature.example_index] for feature in features_list]
ex_feat = [flattend_examples, features_list]

# Encoding
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
batch_size = kwargs["batch_size"]
all_starts = []
all_ends = []
for i in tqdm(range(0, len(ex_feat[0]), batch_size), desc="Querying model", disable=not kwargs["enable_tqdm"]):
batch = ex_feat[1][i : i + batch_size]
fw_args = {k: [feature.__dict__[k] for feature in batch] for k in model_input_names}
with self.device_placement():
if not self.use_onnx:
if self.framework == "tf":
@@ -1728,51 +1786,35 @@ def __call__(self, *args, **kwargs):
start, end = start.cpu().numpy(), end.cpu().numpy()
else:
start, end = self.model.run(None, fw_args)[:2]
# Shape of start and end = (batch_size, context_len)
all_starts.extend(start)
all_ends.extend(end)
all_starts = np.stack(all_starts).tolist()
all_ends = np.stack(all_ends).tolist()
ex_feat.append(all_starts)
ex_feat.append(all_ends)

ex_feat_dict = OrderedDict()
for index, example in enumerate(ex_feat[0]):
if example not in ex_feat_dict:
ex_feat_dict[example] = [(ex_feat[1][index], ex_feat[2][index], ex_feat[3][index])]
else:
ex_feat_dict[example].append((ex_feat[1][index], ex_feat[2][index], ex_feat[3][index]))

min_null_score = 1000000 # large and positive
answers = []
for (feature, start_, end_) in zip(features, start, end):
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

# Generate mask
undesired_tokens_mask = undesired_tokens == 0.0

# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start_)
end_ = np.where(undesired_tokens_mask, -10000.0, end_)

# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

if kwargs["handle_impossible_answer"]:
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

# Mask CLS
start_[0] = end_[0] = 0.0

starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
char_to_word = np.array(example.char_to_word_offset)

# Convert the answer (tokens) back to the original text
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]

if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
all_answers += answers
# Extract answers
all_answers = []
for example, features_and_positions in tqdm(
ex_feat_dict.items(), desc="Extracting answers", disable=not kwargs["enable_tqdm"]
):
all_answers.append(
self.extract_answers(
example,
features_and_positions,
kwargs["handle_impossible_answer"],
kwargs["topk"],
kwargs["max_answer_len"],
)
)

if len(all_answers) == 1:
return all_answers[0]
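With this rewrite, features from all examples are converted once and pushed through the model in fixed-size batches, instead of one `squad_convert_examples_to_features` call and one forward pass per example. A sketch of how a caller might use the two new keyword arguments (the checkpoint name and inputs are illustrative only, not part of this PR):

```python
from transformers import pipeline

# Hypothetical caller-side usage of the batching knobs added above.
nlp = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

answers = nlp(
    question=["Where do I live?", "What is my name?"],
    context=["I live in Paris.", "My name is Sarah."],
    batch_size=16,      # features sent through the model per forward pass (default 16)
    enable_tqdm=False,  # silences the feature-conversion and "Querying model" bars
    topk=1,
)
```
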
4 changes: 2 additions & 2 deletions tests/test_pipelines.py
@@ -650,13 +650,13 @@ def _test_qa_pipeline(self, nlp):
]
self.assertIsNotNone(nlp)

mono_result = nlp(valid_inputs[0])
mono_result = nlp(valid_inputs[0])[0]
self.assertIsInstance(mono_result, dict)

for key in output_keys:
self.assertIn(key, mono_result)

multi_result = nlp(valid_inputs)
multi_result = nlp(valid_inputs)[0]
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], dict)

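The added `[0]` indexing reflects the new return shape: the pipeline now returns one top-k answer list per example, so even a single-example call yields a list rather than a bare dict. A sketch of the shapes the updated assertions expect (illustrative values only):

```python
# Illustrative shapes only (scores and offsets are made up), sketching why the
# updated assertions index with [0].
mono = [{"score": 0.91, "start": 10, "end": 15, "answer": "Paris"}]   # nlp(single_input)
multi = [                                                             # nlp(list_of_inputs)
    [{"score": 0.91, "start": 10, "end": 15, "answer": "Paris"}],
    [{"score": 0.87, "start": 11, "end": 16, "answer": "Sarah"}],
]

assert isinstance(mono[0], dict)    # mono_result = nlp(valid_inputs[0])[0]
assert isinstance(multi[0], list)   # multi_result = nlp(valid_inputs)[0]
assert isinstance(multi[0][0], dict)
```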