[Performance improvement] "Bad tokens ids" optimization #6064
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master    #6064      +/-   ##
==========================================
- Coverage   78.49%   78.38%    -0.11%
==========================================
  Files         146      147        +1
  Lines       26335    26384       +49
==========================================
+ Hits        20671    20681       +10
- Misses       5664     5703       +39
Continue to review full report at Codecov.
src/transformers/generation_utils.py
Outdated
        banned_mask_list.append([idx, token])
if len(banned_mask_list) > 0:
    banned_mask = torch.LongTensor(banned_mask_list)
    indices = torch.ones(len(banned_mask))
`ones_like`?
src/transformers/generation_utils.py
Outdated
banned_mask = (
    torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
    .to_dense()
    .bool()
    .to(scores.device)
Not 100% sure I understand what's going on here.
why not just make a dense tensor to start?
Oh, I kinda understand now after reading the torch.sparse documentation.
This is, as far as I know, the most effective way to create a mask tensor from a list of [idx, pos] pairs. I am not aware of a method to create a dense tensor populated with ones at specific locations, e.g.

    [ 0 1 1 ]
    [ 0 0 0 ]
    [ 1 0 0 ]

from `[[0, 1], [0, 2], [2, 0]]`. As far as I know, a dense-matrix approach would require a zeros initialization followed by a for loop to modify the tensor element by element.
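For illustration, a minimal standalone sketch of this trick (values are hypothetical; `torch.sparse_coo_tensor` is the modern spelling of the legacy `torch.sparse.LongTensor` constructor used in the diff):

```python
import torch

# Coordinates of the ones in the mask, as [row, col] pairs (hypothetical example).
positions = torch.LongTensor([[0, 1], [0, 2], [2, 0]])
values = torch.ones(len(positions))

# The sparse constructor expects indices as a 2 x nnz tensor, hence the transpose.
mask = torch.sparse_coo_tensor(positions.t(), values, (3, 3)).to_dense().bool()
# mask is now:
# tensor([[False,  True,  True],
#         [False, False, False],
#         [ True, False, False]])

scores = torch.randn(3, 3)
scores.masked_fill_(mask, -float("inf"))  # ban the marked positions in place
```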
Maybe a little comment would help people better understand!
src/transformers/generation_utils.py
Outdated
    .bool()
    .to(scores.device)
)
scores.masked_fill_(banned_mask, -float("inf"))
love this!
This is a great change as far as I can tell, thanks for the contribution! Since the generation code is very important for model performance and fairly hairy, we need to be pretty careful before we merge.

1. Could you add a test to prove that this works on edge cases? It should be model independent (a sketch follows below) and check that the code:
   - passes with `bad_words_ids = []`
   - handles a huge list of `bad_words_ids` (maybe you can add a `@timeout_decorator`)
   - passes with `bad_words_ids = [[eos_token]]` (we execute lines 95-107 on an empty list).
2. Could you keep the `calc_banned_bad_words_ids` helper function (will make 1 easier)?
3. After you're done, could you verify that these integration tests still pass:
   - `RUN_SLOW=1 pytest tests/test_modeling_marian.py`
   - `RUN_SLOW=1 pytest tests/test_modeling_bart.py`

Thanks again and sorry for the trouble!
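A rough sketch of what such a model-independent edge-case test could look like (entirely hypothetical: `apply_ban` stands in for whatever helper ends up applying the mask, and the PR's final test differs):

```python
import timeout_decorator  # pip install timeout-decorator
import torch


def apply_ban(scores, banned_tokens):
    # Hypothetical helper mirroring the masking logic from the diff above.
    banned_mask_list = [[i, t] for i, toks in enumerate(banned_tokens) for t in toks]
    if not banned_mask_list:
        return
    coords = torch.LongTensor(banned_mask_list)
    mask = torch.sparse_coo_tensor(
        coords.t(), torch.ones(len(coords)), scores.size()
    ).to_dense().bool()
    scores.masked_fill_(mask, -float("inf"))


@timeout_decorator.timeout(5)  # guards against a performance regression
def test_banned_tokens_edge_cases():
    # Note: signal-based timeouts do not work on Windows, as the commits mention.
    batch_size, vocab_size, eos_token = 8, 300, 0
    scores = torch.randn(batch_size, vocab_size)

    apply_ban(scores, [[] for _ in range(batch_size)])  # empty lists: must be a no-op
    assert not torch.isinf(scores).any()

    apply_ban(scores, [[eos_token]] * batch_size)  # a single special token
    assert torch.isinf(scores[:, eos_token]).all()

    apply_ban(scores, [list(range(vocab_size))] * batch_size)  # huge list, must stay fast
    assert torch.isinf(scores).all()
```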
@sshleifer Thank you very much for the review. I have added unit tests for the modified method that hopefully align with what you had in mind. I have re-run the Marian integration tests, which pass without issue. I somehow have issues running the BART integration tests (even on master): `from .test_configuration_common import ConfigTester` fails with `ImportError: attempted relative import with no known parent package`. Regarding point 2, could you please clarify?

@JetRunner I have added a comment to clarify the mask tensor generation. I am currently running into an issue with a failing TensorFlow test, but I do not see how it relates to the proposed changes. Thank you!
…_bad_words_list` unit test
…_bad_words_list` unit test (timeout does not work on Windows)
tests/test_modeling_common.py
Outdated
def test_postprocess_next_token_scores(self):

    config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
    model = MarianMTModel(config=config)
you might want to test a smaller model here.
you should definitely test a tiny model, since the actual model doesn't matter here. Use sshleifer/tiny-marian-en-de/
tests/test_modeling_common.py
Outdated
        32,
        5,
    )
except timeout_decorator.timeout_decorator.TimeoutError:
This is meant to test that the code stays fast, so you should let it raise. Ideally this test would have failed using the old logic.

I misread your code, sorry. How about `def set_scores_to_inf_for_banned_tokens(self, scores, bad_words_ids) -> None:`, just for the sake of namespace control? Also, how significant is the speedup here?
@sshleifer This makes sense, just pushed a few more changes.
just nitpicks, great gist!
src/transformers/generation_utils.py
Outdated
for idx, batch_banned_tokens in enumerate(banned_tokens):
    for token in batch_banned_tokens:
        banned_mask_list.append([idx, token])
if len(banned_mask_list) > 0:
(nit)

    if not banned_mask_list:
        return

will save some indentation
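For reference, stitching together the hunks quoted in this thread, with this early-return nit applied and the method name suggested above, gives roughly the following shape (a sketch assembled from the review snippets, not necessarily the exact merged code):

```python
import torch


class GenerationMixinSketch:
    # Sketch only: in transformers this would live on the generation mixin.
    def set_scores_to_inf_for_banned_tokens(self, scores, banned_tokens) -> None:
        """Set the scores of banned token ids to -inf, modifying `scores` in place."""
        # Collect [batch_idx, token_id] coordinates for every banned token.
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                banned_mask_list.append([idx, token])
        if not banned_mask_list:
            return
        banned_mask = torch.LongTensor(banned_mask_list)
        indices = torch.ones(len(banned_mask))
        # Build a dense boolean mask from the coordinates via a sparse tensor
        # (modern PyTorch would use torch.sparse_coo_tensor here).
        banned_mask = (
            torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
            .to_dense()
            .bool()
            .to(scores.device)
        )
        scores.masked_fill_(banned_mask, -float("inf"))
```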
tests/test_modeling_common.py
Outdated
@@ -971,3 +974,67 @@ def test_top_k_top_p_filtering(self):

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))

    @require_torch
    def test_postprocess_next_token_scores(self):
I'm worried that this is going to run for every model. Maybe it should be in its own test_generation_utils.py file?
tests/test_modeling_common.py
Outdated
masked_scores = [
    [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
    [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
    [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)],
the 0th token is masked by default? is that a special token? if so, let's get it from `config.eos_token_id`.

this could be more readable if it were `[(i, 0) for i in range(batch_size)]` (same for 988 and everything but the first entry of 989), as illustrated below.
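For illustration, the suggested comprehension would replace an explicit row like the last one above as follows (hypothetical, with `batch_size = 8`):

```python
batch_size = 8

# Explicit form from the test:
#   [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)]
# Comprehension form suggested in the review:
row = [(i, 0) for i in range(batch_size)]
```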
@sshleifer Thank you again for the thorough review! Tried to address the latest comments - I believe it cleans things up quite a bit. Thank you for the suggestions!
LGTM @LysandreJik, merge whenever!
That's a very useful change! Also linking this issue here: #5345 since it can probably be resolved the same way.
This is ready to be merged @LysandreJik!
Great, thanks @guillaume-be!
Running some benchmarks, I noticed that the generation pipeline was varying quite a bit in execution time. The banned-token masking step in particular seems fairly expensive: in some experiments, up to 30% of the time for an entire generation run was spent in this step, which seems too high considering its expected simplicity.
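For context, a minimal sketch of how one might time the masking step in isolation (hypothetical sizes; `apply_ban` is the illustrative helper sketched earlier in this thread, not an actual transformers API):

```python
import time

import torch

batch_size, vocab_size = 8, 50000  # hypothetical generation-sized tensors
banned = [list(range(0, vocab_size, 7)) for _ in range(batch_size)]

start = time.perf_counter()
for _ in range(100):
    scores = torch.randn(batch_size, vocab_size)
    apply_ban(scores, banned)  # the illustrative helper defined above
elapsed = time.perf_counter() - start
print(f"masking step: {elapsed:.3f}s over 100 iterations")
```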
This PR accelerates the entire generation pipeline for models that use `bad_words_ids` in their configuration by around 20% on a GPU-enabled node (this includes, for example, translation using the Marian models). The following changes contribute to the performance improvement: