Add Summarization to Pipelines #3128

Merged: 15 commits into master from ss-summarization-to-pipelines on Mar 17, 2020
Conversation

@sshleifer (Contributor) commented Mar 4, 2020:

Choices:

  1. This is not TextGenerationPipeline, so it only supports bart-large-cnn.
  2. It doesn't return the input back to the caller because it is annoyingly long.
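
For context, a minimal sketch of how the new pipeline is meant to be called, pieced together from the snippets in this thread (the article text is a placeholder; the `summary_text` key and `return_text` flag come from the diff below):

```python
from transformers import pipeline

# Defaults to the bart-large-cnn checkpoint discussed in this PR (PyTorch only at the time).
summarizer = pipeline("summarization")

article = (
    "The transformers library is adding a summarization pipeline built on BART "
    "fine-tuned on CNN/DailyMail. It produces abstractive summaries of long articles."
)

# return_text=True yields decoded summaries; return_tensors=True yields generated token ids.
result = summarizer(article, return_text=True)
print(result[0]["summary_text"])
```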

@codecov-io commented Mar 5, 2020:

Codecov Report

Merging #3128 into master will increase coverage by 0.05%.
The diff coverage is 95.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3128      +/-   ##
==========================================
+ Coverage   78.02%   78.08%   +0.05%     
==========================================
  Files          98       98              
  Lines       16670    16689      +19     
==========================================
+ Hits        13007    13031      +24     
+ Misses       3663     3658       -5     
| Impacted Files | Coverage Δ |
|---|---|
| `src/transformers/__init__.py` | 98.91% <ø> (ø) |
| `src/transformers/pipelines.py` | 72.53% <95.00%> (+1.57%) ⬆️ |
| `src/transformers/tokenization_utils.py` | 91.99% <0.00%> (+0.14%) ⬆️ |
| `src/transformers/modeling_tf_utils.py` | 88.37% <0.00%> (+0.17%) ⬆️ |
| `src/transformers/modeling_utils.py` | 94.14% <0.00%> (+0.27%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3814e16...a123599.

@@ -1324,7 +1421,7 @@ def pipeline(

     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
-        models, config, tokenizer = tuple(targeted_task["default"].values())
+        models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
Contributor Author:

This removes a dependency on dict ordering.

Member:

Py36 dicts are ordered (and we might move to Python3.6 "soon") so maybe just document this gotcha, but do not change this? Or do you need it for your tests to pass?

Contributor Author:

In my code, the order changes in Python 3.5 if the value for config (one of the keys in the dict) is set to bart-large-cnn. Since we still support Python 3.5, it seems safer (and more explicit) to me to avoid relying on dict ordering.

Member:

Ok then

length_penalty=2.0,
max_length=140,
min_len=20,
no_repeat_ngram_size=3
Member:

Do we really want to expose all those options, or just a minimal subset? (The other pipelines do not expose any options.)

Member:

That's a good point.
I feel like selecting just a subset of the generate method arguments is a bit arbitrary.

What about something like what we usually do:

  • good defaults for people who just want something to work out-of-the-box (like the one here), and
  • full customizability for the others, e.g. by having a generate_kwargs dict that is passed directly to the model.generate() method (see the sketch below this list), and linking to the generate() method doc in the docstring for the list of arguments (which will likely evolve in the future as better decoding mechanisms are designed).
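
A rough sketch of that pass-through idea (the helper below and its exact tokenizer/model calls are illustrative, using the present-day API, not the code merged in this PR):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def summarize(text, model, tokenizer, **generate_kwargs):
    """Summarize `text`, forwarding any extra keyword arguments to model.generate()."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Whatever the caller passes (num_beams, length_penalty, no_repeat_ngram_size, ...)
    # goes straight to generate(), so the pipeline never has to enumerate decoding options.
    summary_ids = model.generate(inputs["input_ids"], **generate_kwargs)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Usage (checkpoint as referenced in this thread; newer releases host it under facebook/):
# tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
# print(summarize(article, model, tokenizer, num_beams=4, max_length=140, min_length=20))
```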

Contributor Author:

Love the idea of passing through generate_kwargs.
In terms of exposing options, I was preparing to support different kwargs for bart-large-xsum, because the author wrote "Use beam=6, lenpen=1.0, max_len_b=60, min_len=10 for Xsum Generation" in https://github.com/pytorch/fairseq/blob/master/examples/bart/README.summarization.md#L121

@thomwolf (Member) commented Mar 5, 2020:

That's also a good point. We don't want to expose two different pipelines for two trained models.

Here is what we should do if it's not already the case: all the parameters of the generate() method are actually stored in the configuration of the pretrained models (or, more precisely, they will be once #3140 is merged; see here what I'm talking about). So you should update the configuration of the two pretrained summarization weights on S3 with the best decoding parameters selected by the authors, rather than hard-coding any model-specific default values here in the pipeline.

cc @patrickvonplaten since it's related to his work
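
For illustration, here is roughly what storing per-checkpoint decoding defaults in the model config could look like once #3140 is in (the attribute names mirror generate()'s arguments; the values echo the snippet quoted above and are not the configs that were actually uploaded):

```python
from transformers import BartConfig

# Illustrative only: decoding defaults attached to the bart-large-cnn config,
# so the pipeline can call model.generate() without hard-coding any values.
config = BartConfig.from_pretrained("facebook/bart-large-cnn")
config.length_penalty = 2.0
config.max_length = 140
config.min_length = 20          # the xsum checkpoint would carry different values
config.no_repeat_ngram_size = 3
# config.save_pretrained("./bart-large-cnn")  # then re-upload next to the weights on S3
```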

@julien-c requested a review from @thomwolf on March 5, 2020 at 05:41
@thomwolf (Member) left a comment:

Ok good first draft, a few comments and questions.

Let's merge #3140 first to be able to use the default generation configuration.

Comment on lines +1206 to +1207
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
Member:

I don't think we have this argument anymore

Member:

We have it for task-specific pipelines, just not for the pipeline factory. Though I agree it should be deleted for the task-specific pipelines as well.

Comment on lines 1241 to 1252
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.

num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
length_penalty: (`optional`) float Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int.
The number of independently computed returned sequences for each element in the batch. Default to 1.
min_len: (`optional`) int
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
@thomwolf (Member) commented Mar 5, 2020:

As discussed under @julien-c's comment above, remove all these arguments and replace them by default values in the pretrained model config.json.

https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForMaskedLM.generate

"""
assert return_tensors or return_text
Member:

It's good practice to have an associated error message with assert when possible (here it is possible).
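
For instance, something along these lines (the message wording is just illustrative):

```python
assert return_tensors or return_text, (
    "You must specify return_tensors=True and/or return_text=True."
)
```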

record["summary_token_ids"] = summary
if return_text:
record["summary_text"] = self.tokenizer.decode(
summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
Member:

We could expose clean_up_tokenization_spaces as an argument I think
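
A minimal sketch of exposing that flag instead of hard-coding it (the helper name is illustrative):

```python
def decode_summary(tokenizer, summary_ids, clean_up_tokenization_spaces=False):
    """Decode generated ids, letting the caller decide whether to clean up tokenization spaces."""
    return tokenizer.decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
    )
```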

Comment on lines 1290 to 1291
def _forward(self, *args, **kwargs):
raise NotImplementedError("Should not be called")
Member:

Why do we have this method?

Contributor Author:

The parent implements it, so I didn't want to inherit a method that isn't supported and have users expect it to work.
I think since the method is private, I'll just delete this.

"default": {
"model": {"pt": "bart-large-cnn", "tf": None},
"config": None,
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
Member:

I don't think we have to prevent the (future) use of fast tokenizers here. Do we?

cc @mfuntowicz

Contributor Author:

All the other pipelines had that, so I continued the practice; happy to go the other way.

Member:

In 2.5.1 we disabled the use of fast tokenizers by default and, to be consistent, did the same for pipelines. RoBERTa still needs some work to give the same output as the Python (slow) tokenizer.

=> I would recommend staying like this (i.e. disabled by default) for now.

Also, for use_fast to work, we need to provide a fast backend, inheriting from RobertaTokenizerFast (very similar example: https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_distilbert.py#L75).
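
A hypothetical sketch of such a backend, mirroring the DistilBERT example linked above (the class name and vocab wiring are illustrative; no fast BART tokenizer existed when this comment was written):

```python
from transformers import RobertaTokenizerFast

class BartTokenizerFast(RobertaTokenizerFast):
    """Hypothetical Rust-backed BART tokenizer reusing the RoBERTa fast backend,
    since BART shares RoBERTa's byte-level BPE vocabulary."""
    # vocab_files_names and the pretrained vocab/merges maps would mirror the slow
    # BartTokenizer's, pointing at the bart-large-cnn files.
```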

@sshleifer (Contributor Author):

I addressed all comments, and am ready for review @julien-c @thomwolf.

@julien-c merged commit 38a555a into master on Mar 17, 2020
@julien-c deleted the ss-summarization-to-pipelines branch on March 17, 2020 at 22:04
@Weilin37:

Is this pipeline ready to go? When I tried to run an example it said that the summarization pipeline is not one of the options.

@sshleifer (Contributor Author):

Hey, @Weilin37 .
Could you send a snippet of code so that I can reproduce your error?
Thanks!

@julien-c (Member):

@Weilin37 are you running from master?

@Weilin37:

> @Weilin37 are you running from master?

Hi, yes, it is resolved now. I thought I had upgraded, but I hadn't.
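
(Running from master here means installing from source rather than from a pip release, e.g. `pip install git+https://github.com/huggingface/transformers`, since the summarization pipeline had not yet shipped in a released version at that point.)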

jplu pushed a commit to jplu/transformers that referenced this pull request Mar 25, 2020
* passing

* Undo stupid chg

* docs

* undo rename

* delete-cruft

* only import if you have torch

* Dont rely on dict ordering

* Fix dict ordering upstream

* docstring link

* docstring link

* remove trailing comma for 3.5 compat

* new name

* delegate kwarging

* Update kwargs