
[Performance improvement] "Bad tokens ids" optimization #6064


Merged

Conversation

guillaume-be
Contributor

Running some benchmarks, I noticed that the generation pipeline varied quite a bit in execution time. The banned-token masking in particular seems to be fairly expensive (in some experiments, up to 30% of the time for an entire generation process was spent in this step, which seems too high considering its expected simplicity).

This PR accelerates the entire generation pipeline for models using bad_words_ids in their configuration by around 20% on a GPU-enabled node (this includes, for example, translation using the Marian models).

The following changes contribute to the performance improvement:

  • Single conversion from tensor to list. The previous approach accessed the GPU buffer for every banned token and every batch element, making this operation slower than the entire forward pass through the model
  • Vectorized update of the banned tokens using a masked fill (see the sketch after this list)
  • Skipping the EOS token for the banned tokens (avoiding a potential duplicate masking)
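For readers skimming the diff, a rough standalone sketch of the new masking path (illustrative only: the helper and variable names follow the discussion below, and torch.sparse_coo_tensor is used here as the modern equivalent of the torch.sparse.LongTensor constructor that appears in the diff):

```python
import torch

def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens) -> None:
    """Set scores at banned (batch_idx, token_id) positions to -inf in one vectorized step."""
    # banned_tokens: one list of banned token ids per batch element
    banned_mask_list = [
        [batch_idx, token]
        for batch_idx, batch_banned in enumerate(banned_tokens)
        for token in batch_banned
    ]
    if not banned_mask_list:
        return
    banned_mask = torch.LongTensor(banned_mask_list)  # shape (n_banned, 2)
    values = torch.ones(len(banned_mask))
    # Build a dense boolean mask with True at every banned position, then mask in one call.
    banned_mask = (
        torch.sparse_coo_tensor(banned_mask.t(), values, scores.size())
        .to_dense()
        .bool()
        .to(scores.device)
    )
    scores.masked_fill_(banned_mask, -float("inf"))
```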

@codecov

codecov bot commented Jul 27, 2020

Codecov Report

Merging #6064 into master will decrease coverage by 0.10%.
The diff coverage is 27.77%.


@@            Coverage Diff             @@
##           master    #6064      +/-   ##
==========================================
- Coverage   78.49%   78.38%   -0.11%     
==========================================
  Files         146      147       +1     
  Lines       26335    26384      +49     
==========================================
+ Hits        20671    20681      +10     
- Misses       5664     5703      +39     
| Impacted Files | Coverage Δ |
| --- | --- |
| src/transformers/data/test_generation_utils.py | 0.00% <0.00%> (ø) |
| src/transformers/generation_utils.py | 96.64% <93.75%> (-0.19%) ⬇️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

```python
        banned_mask_list.append([idx, token])
if len(banned_mask_list) > 0:
    banned_mask = torch.LongTensor(banned_mask_list)
    indices = torch.ones(len(banned_mask))
```
Contributor

ones_like?

Comment on lines 103 to 107
```python
banned_mask = (
    torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
    .to_dense()
    .bool()
    .to(scores.device)
```
Contributor

I don't 100% understand what's going on here.

Contributor

why not just make a dense tensor to start?

Contributor

Oh, I kind of understand now after reading the torch.sparse documentation.

Contributor Author

This is, as far as I know, the most effective way to create a mask tensor from a list of [idx, pos] pairs. I am not aware of a method to create a dense tensor populated with ones at specific locations, e.g.

[ 0  1  1 ]
[ 0  0  0 ]
[ 1  0  0 ]

from: [[0, 1], [0, 2], [2, 0]]. As far as I know, a dense-matrix approach would require a zeros initialization followed by a for loop modifying the tensor element by element.
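For illustration, a tiny standalone version of this construction for the 3×3 example above (using torch.sparse_coo_tensor, the modern spelling of the torch.sparse.LongTensor constructor quoted in the diff):

```python
import torch

coords = torch.LongTensor([[0, 1], [0, 2], [2, 0]])  # (row, col) positions to set to 1
mask = (
    torch.sparse_coo_tensor(coords.t(), torch.ones(len(coords)), (3, 3))
    .to_dense()
    .bool()
)
# tensor([[False,  True,  True],
#         [False, False, False],
#         [ True, False, False]])
```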

Contributor

Maybe a little comment would help people better understand!

```python
    .bool()
    .to(scores.device)
)
scores.masked_fill_(banned_mask, -float("inf"))
```
Contributor

love this!

Contributor

@sshleifer left a comment

This is a great change as far as I can tell, thanks for the contribution! Since the generation code is very important for model performance and fairly hairy, we need to be pretty careful before we merge.

  1. Could you add a test to prove that this works on edge cases? It should be model independent.
  • passes bad_words_ids = []
  • huge list of bad_words_ids (maybe you can add an @timeout_decorator)
  • passes bad_words_ids = [[eos_token]] (we execute lines 95-107 on an empty list).
  2. Could you keep the calc_banned_bad_words_ids helper function (will make 1 easier)?

  3. After you're done, could you verify that these integration tests still pass:

RUN_SLOW=1 pytest tests/test_modeling_marian.py
RUN_SLOW=1 pytest tests/test_modeling_bart.py

Thanks again and sorry for the trouble!

@guillaume-be
Contributor Author

guillaume-be commented Jul 28, 2020

@sshleifer Thank you very much for the review. I have added unit tests for the modified method that hopefully align with what you had in mind. I have re-run the Marian integration tests, which pass without issue. I somehow have issues running the BART integration tests (even on master) due to an ImportError and am unable to see whether these still pass:

from .test_configuration_common import ConfigTester
ImportError: attempted relative import with no known parent package

Regarding point 2), could you please clarify? The calc_banned_bad_words_ids helper still exists (and is used) in the proposed PR. Would you recommend making a copy of it instead of changing its behaviour? Then the original calc_banned_bad_words_ids would no longer be used anywhere.

@JetRunner I have added a comment to clarify the mask tensor generation

I am currently running into issues with a TensorFlow test failing, but I do not see how it relates to the proposed changes.

Thank you!

```python
def test_postprocess_next_token_scores(self):
    config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
    model = MarianMTModel(config=config)
```
Contributor

you might want to test a smaller model here.

Contributor

You should definitely test a tiny model, since the actual model doesn't matter here.
Use sshleifer/tiny-marian-en-de/

```python
        32,
        5,
    )
except timeout_decorator.timeout_decorator.TimeoutError:
```
Contributor

@sshleifer Jul 28, 2020

This is meant to test that the code stays fast, so you should let it raise. Ideally this test would have failed using the old logic.

@sshleifer
Contributor

sshleifer commented Jul 28, 2020

I misread your code, sorry.
My point 2 should be that it feels like the new masking logic could be put into a helper method like

```python
def set_scores_to_inf_for_banned_tokens(self, scores, bad_words_ids) -> None:
```

just for the sake of namespace control.
You could also test that method without running generate.
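As an illustration, such a helper could be exercised directly on a small score tensor without calling generate. The sketch below assumes the hypothetical standalone signature from the earlier sketch (scores plus per-batch banned token lists), not the exact method that was merged:

```python
import torch

def test_set_scores_to_inf_for_banned_tokens_edge_cases():
    batch_size, vocab_size = 2, 10
    scores = torch.zeros(batch_size, vocab_size)

    # Empty ban lists: scores must be left untouched.
    set_scores_to_inf_for_banned_tokens(scores, banned_tokens=[[], []])
    assert torch.isfinite(scores).all()

    # Ban token 3 for the first batch element only.
    set_scores_to_inf_for_banned_tokens(scores, banned_tokens=[[3], []])
    assert scores[0, 3] == -float("inf")
    assert torch.isfinite(scores[1]).all()
```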

Also, how significant is the speedup here?

@guillaume-be
Contributor Author

@sshleifer This makes sense, just pushed a few more changes:

  • Moved the masking to a utility function
  • Updated the unit test to let it fail if it hits the timeout. As this is configuration dependent, the limit was increased to 10 in case the available CI compute power fluctuates. In general, I am not sure unit tests are the best way to perform performance regression testing
  • I have created a gist to share the performance difference between the current and the proposed approach: https://gist.github.com/guillaume-be/e335b099005e9bf38448d0e2eb02f74f . On this simple example with a GPU on Colab, the proposed approach is twice as fast. This actually has a significant impact on the entire generation process, but I did not manage to create a good example on Colab (the resources fluctuate too much from notebook to notebook, and I am not aware of a way to change a library version within the same notebook). Running locally with a consumer-grade Turing GPU (2070), I observe a time reduction of around 20% for the end-to-end generation process (a rough shape of such a timing comparison is sketched below).
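For reference, a rough shape of a micro-benchmark for just the masking step (the sizes, loop count, and helper name from the sketch above are illustrative; the linked gist is the authoritative comparison):

```python
import time
import torch

batch_size, vocab_size, n_banned = 8, 59000, 100  # arbitrary, Marian-sized vocabulary
scores = torch.randn(batch_size, vocab_size, device="cuda")
banned_tokens = [list(range(n_banned)) for _ in range(batch_size)]

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
torch.cuda.synchronize()
print(f"masking: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms/call")
```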

Contributor

@sshleifer left a comment

just nitpicks, great gist!

```python
for idx, batch_banned_tokens in enumerate(banned_tokens):
    for token in batch_banned_tokens:
        banned_mask_list.append([idx, token])
if len(banned_mask_list) > 0:
```
Contributor

(nit)

```python
if not banned_mask_list:
    return
```

will save some indentation

@@ -971,3 +974,67 @@ def test_top_k_top_p_filtering(self):

```python
        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))

    @require_torch
    def test_postprocess_next_token_scores(self):
```
Contributor

I'm worried that this is going to run for every model. Maybe it should be in its own test_generation_utils.py file?

```python
masked_scores = [
    [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
    [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
    [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)],
```
Contributor

The 0th token is masked by default? Is that a special token? If so, let's get it from config.eos_token_id.

This could be more readable if it were [(i, 0) for i in range(batch_size)] (same for 988, and everything but the first entry of 989).
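For what it's worth, a hypothetical restructuring of that fixture along the suggested lines might look like this (batch_size and the token ids are taken from the quoted snippet):

```python
batch_size = 8
masked_scores = [
    [(i, 299) for i in range(batch_size)],             # bad word 299 banned in every row
    [(1, 24)] + [(i, 54) for i in range(batch_size)],  # one row-specific ban plus a shared one
    [(i, 0) for i in range(batch_size)],               # token 0 (eos?) banned everywhere
]
```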

@guillaume-be
Contributor Author

@sshleifer Thank you again for the thorough review! I tried to address the latest comments; I believe this cleans things up quite a bit. Thank you for the suggestions.

Contributor

@sshleifer left a comment

LGTM @LysandreJik, merge whenever!

Contributor

@patrickvonplaten left a comment

That's a very useful change! Also linking this issue here: #5345 since it can probably be resolved the same way.

@sshleifer
Contributor

This is ready to be merged @LysandreJik !

Member

@LysandreJik left a comment

Great, thanks @guillaume-be !
