Add Summarization to Pipelines #3128
@@ -60,6 +60,7 @@
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
)
from .modeling_bart import BartForConditionalGeneration


logger = logging.getLogger(__name__)
@@ -1104,6 +1105,107 @@ def span_to_answer(self, text: str, start: int, end: int):
        return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}

class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents

    Usage::

        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.")

    Supported Models:
        The models that this pipeline can use are models that have been fine-tuned on a summarization task,
        which is currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``.

    Arguments:
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
            checkpoint identifier or an actual pre-trained model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.

            If :obj:`None`, the default of the pipeline will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`.

            If :obj:`None`, the default of the pipeline will be loaded.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`, defaults to :obj:`None`):
            The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
            installed.

            If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to PyTorch.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to :obj:`-1`):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
            on the associated CUDA device id.
    """

    task = "summarization"

    def __call__(
        self,
        *documents,
        return_tensors=False,
        return_text=True,
        max_length=142,
        min_length=21,
        clean_up_tokenization_spaces=False,
        **generate_kwargs
    ):
        r"""
        Args:
            *documents: (list of strings) articles to be summarized
            return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
            return_tensors: (bool, default=False) whether to add the raw "summary_token_ids" to each result

            max_length: (`optional`) int
                The max length of the sequence to be generated. Does not include tokens in input_ids.
            min_length: (`optional`) int
                The min length of the sequence to be generated.
            no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
            clean_up_tokenization_spaces: (`optional`) bool whether to clean up the potential extra spaces in the text output
            **generate_kwargs: extra kwargs passed to `self.model.generate`_

        Returns:
            list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize

        .. _`self.model.generate`:
            https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate

        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        if self.framework == "tf":
            raise NotImplementedError("Tensorflow not supported")
        with self.device_placement():
            inputs = self._parse_and_tokenize(*documents)
            inputs = self.ensure_tensor_on_device(**inputs)
            summaries = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                **generate_kwargs,
            )
            results = []
            for summary in summaries:
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
                    )
                results.append(record)
            return results


# Register all the supported task here
SUPPORTED_TASKS = {
    "feature-extraction": {
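To make the new pipeline's call signature concrete, here is a minimal usage sketch. It assumes this PR is merged and that the bart-large-cnn weights referenced in the registry below are downloadable; the article text and the printed output are illustrative placeholders.

    from transformers import pipeline

    # Resolves to the task default registered below:
    # BartForConditionalGeneration + the bart-large-cnn tokenizer (use_fast=False).
    summarizer = pipeline("summarization")

    article = (
        "The tower is 324 metres tall, about the same height as an 81-storey building, "
        "and is the tallest structure in Paris."
    )

    # return_text=True (the default) adds a decoded "summary_text" per document;
    # unrecognized keyword arguments such as num_beams are forwarded to model.generate.
    print(summarizer(article, min_length=10, max_length=60, num_beams=4))
    # -> [{'summary_text': '...'}]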
@@ -1162,6 +1264,16 @@ def span_to_answer(self, text: str, start: int, end: int):
            "tokenizer": ("distilroberta-base", {"use_fast": False}),
        },
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "pt": BartForConditionalGeneration if is_torch_available() else None,
        "tf": None,
        "default": {
            "model": {"pt": "bart-large-cnn", "tf": None},
            "config": None,
            "tokenizer": ("bart-large-cnn", {"use_fast": False}),
Review comment on the use_fast flag (a short sketch follows this registry entry): I don't think we have to prevent the (future) use of fast tokenizers here. Do we? cc @mfuntowicz

Reply: All the other pipelines had that, so I continued the practice; happy to go the other way.

Reply: In 2.5.1, we disabled the use of fast tokenizers by default and, to be consistent, did the same for pipelines. RoBERTa still needs some work to give the same output as the Python tokenizer. => I would recommend staying like this (i.e. disabled by default) for now. Also, for use_fast to work, we need to provide a fast backend, inheriting from …
        },
    },
}
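As a rough illustration of the override path touched on in the thread above: a caller who wants to control the tokenizer can pass one explicitly instead of relying on the registry default. The checkpoint name comes from the registry entry; whether a fast BART backend exists at this point is not assumed here.

    from transformers import AutoTokenizer, pipeline

    # The registry entry above pins the default tokenizer with use_fast=False.
    # An explicitly constructed tokenizer bypasses that default; opting into a
    # fast tokenizer would additionally require a fast BART backend, which the
    # discussion above suggests does not exist yet.
    tokenizer = AutoTokenizer.from_pretrained("bart-large-cnn", use_fast=False)
    summarizer = pipeline("summarization", model="bart-large-cnn", tokenizer=tokenizer)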
@@ -1253,7 +1365,7 @@ def pipeline(

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
-        models, config, tokenizer = tuple(targeted_task["default"].values())
+        models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
Review comment: this removes a dependency on dict ordering. (A small illustration follows this hunk.)

Reply: Py36 dicts are ordered (and we might move to Python 3.6 "soon"), so maybe just document this gotcha but do not change this? Or do you need it for your tests to pass?

Reply: In my code, the order changes in py3.5 if the value for config (one of the keys in the dict) is set to …

Reply: Ok then

    model = models[framework]

    # Try to infer tokenizer from model or config name (if provided as str)
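The ordering concern from the thread above is easy to reproduce in isolation. A minimal sketch (the dict contents mirror the registry defaults but are otherwise illustrative):

    # On Python < 3.6, dict literals do not preserve insertion order, so unpacking
    # .values() positionally can silently swap fields between runs or versions.
    default = {
        "model": {"pt": "bart-large-cnn", "tf": None},
        "config": None,
        "tokenizer": ("bart-large-cnn", {"use_fast": False}),
    }

    # Fragile on py3.5: assumes values come back in model, config, tokenizer order.
    models, config, tokenizer = tuple(default.values())

    # Order-independent, as the PR does: index by key explicitly.
    models, config, tokenizer = [default[k] for k in ["model", "config", "tokenizer"]]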
Review comment: I don't think we have this argument anymore.

Reply: We have it for task-specific pipelines, just not for the pipeline factory. Though I agree it should be deleted for the task-specific pipelines as well.