Add Summarization to Pipelines #3128

Merged
merged 15 commits on Mar 17, 2020
5 changes: 5 additions & 0 deletions docs/source/main_classes/pipelines.rst
@@ -61,3 +61,8 @@ QuestionAnsweringPipeline

.. autoclass:: transformers.QuestionAnsweringPipeline


SummarizationPipeline
==========================================

.. autoclass:: transformers.SummarizationPipeline
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -101,6 +101,7 @@
Pipeline,
PipelineDataFormat,
QuestionAnsweringPipeline,
SummarizationPipeline,
TextClassificationPipeline,
TokenClassificationPipeline,
pipeline,
114 changes: 113 additions & 1 deletion src/transformers/pipelines.py
@@ -60,6 +60,7 @@
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
from .modeling_bart import BartForConditionalGeneration


logger = logging.getLogger(__name__)
@@ -1104,6 +1105,107 @@ def span_to_answer(self, text: str, start: int, end: int):
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}


class SummarizationPipeline(Pipeline):
"""
Summarize news articles and other documents

Usage::

summarizer = pipeline("summarization")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")

Supported Models:
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``

Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.

If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.

If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
Comment on lines +1135 to +1136

Member: I don't think we have this argument anymore.

Member: We have it for task-specific pipelines, just not for the pipeline factory. Though I agree it should be deleted for the task-specific pipelines as well.

framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.

If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""

task = "summarization"

def __call__(
self,
*documents,
return_tensors=False,
return_text=True,
max_length=142,
min_length=21,
clean_up_tokenization_spaces=False,
**generate_kwargs
):
r"""
Args:
*documents: (list of strings) articles to be summarized
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
return_tensors: (bool, default=False) whether to add the raw "summary_token_ids" to each result

max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
min_length: (`optional`) int
The minimum length of the sequence to be generated.
no_repeat_ngram_size: (`optional`) int. Ban ngrams of this length from being repeated in the generated text.
clean_up_tokenization_spaces: (`optional`) bool whether to clean up potential extra spaces in the output text
**generate_kwargs: extra kwargs passed to `self.model.generate`_

Returns:
list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize

.. _`self.model.generate`:
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate

"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
if self.framework == "tf":
raise NotImplementedError("Tensorflow not supported")
with self.device_placement():
inputs = self._parse_and_tokenize(*documents)
inputs = self.ensure_tensor_on_device(**inputs)
summaries = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
min_length=min_length,
do_sample=False,
**generate_kwargs,
)
results = []
for summary in summaries:
record = {}
if return_tensors:
record["summary_token_ids"] = summary
if return_text:
record["summary_text"] = self.tokenizer.decode(
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
)
results.append(record)
return results
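
For reference, a minimal end-to-end sketch of how the new pipeline is meant to be called, based on the docstring and `__call__` signature above (this downloads the bart-large-cnn weights; the exact summary text will depend on the model):

    from transformers import pipeline

    # "summarization" resolves to bart-large-cnn via the SUPPORTED_TASKS registry below.
    summarizer = pipeline("summarization")
    article = "Sam Shleifer writes the best docstring examples in the whole world."
    results = summarizer(article, min_length=5, max_length=20)
    # Each record carries "summary_text"; pass return_tensors=True to also get "summary_token_ids".
    print(results[0]["summary_text"])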


# Register all the supported task here
SUPPORTED_TASKS = {
"feature-extraction": {
@@ -1162,6 +1264,16 @@ def span_to_answer(self, text: str, start: int, end: int):
"tokenizer": ("distilroberta-base", {"use_fast": False}),
},
},
"summarization": {
"impl": SummarizationPipeline,
"pt": BartForConditionalGeneration if is_torch_available() else None,
"tf": None,
"default": {
"model": {"pt": "bart-large-cnn", "tf": None},
"config": None,
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
Member: I don't think we have to prevent the (future) use of fast tokenizers here. Do we?

cc @mfuntowicz

Contributor Author: All the other pipelines had that so I continued the practice, happy to go the other way.

Member: In the 2.5.1 release, we disabled the use of fast tokenizers by default and, to be consistent, did the same for pipelines. Roberta still needs some work to give the same output as the Python tokenizer.

=> I would recommend staying like this (i.e. disabled by default) for now.

Also, for use_fast to work, we need to provide a fast backend inheriting from RobertaTokenizerFast (very similar example: https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_distilbert.py#L75).
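
As a purely illustrative sketch (not part of this PR; the class name and wiring are hypothetical), such a fast backend would subclass the existing Roberta fast tokenizer, much like the DistilBert example linked above:

    from transformers import RobertaTokenizerFast

    class BartTokenizerFast(RobertaTokenizerFast):  # hypothetical, not defined in this PR
        # Bart checkpoints reuse the Roberta/GPT-2 BPE vocabulary, so a fast subclass
        # would mainly remap the pretrained vocab/merges files to the bart-* checkpoints
        # (maps omitted here).
        pass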

},
},
}


@@ -1253,7 +1365,7 @@ def pipeline(

# Use default model/config/tokenizer for the task if no model is provided
if model is None:
models, config, tokenizer = tuple(targeted_task["default"].values())
models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
Contributor Author: this removes a dependency on dict ordering

Member: Py36 dicts are ordered (and we might move to Python 3.6 "soon") so maybe just document this gotcha, but do not change this? Or do you need it for your tests to pass?

Contributor Author: In my code, the order changes in py3.5 if the value for config (one of the keys in the dict) is set to bart-large-cnn. Since we still support Python 3.5, it seems safer (and more explicit) to avoid reliance on dict ordering.

Member: Ok then
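
To make the gotcha concrete, a small illustration using the summarization defaults registered above (on CPython 3.5, plain-dict iteration order is not guaranteed):

    default = {"model": {"pt": "bart-large-cnn", "tf": None}, "config": None, "tokenizer": ("bart-large-cnn", {"use_fast": False})}

    # Relies on dict insertion order, which Python 3.5 does not guarantee:
    models, config, tokenizer = tuple(default.values())

    # Explicit key lookup works regardless of ordering:
    models, config, tokenizer = [default[k] for k in ["model", "config", "tokenizer"]]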

model = models[framework]

# Try to infer tokenizer from model or config name (if provided as str)
10 changes: 10 additions & 0 deletions tests/test_pipelines.py
@@ -247,6 +247,16 @@ def test_tf_fill_mask(self):
expected_check_keys=["sequence"],
)

@require_torch
def test_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
nlp = pipeline(task="summarization")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
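
To run only the new test locally, a standard pytest invocation should work (assuming PyTorch is installed so @require_torch does not skip it):

    python -m pytest tests/test_pipelines.py -k test_summarization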


class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):