[Performance improvement] "Bad tokens ids" optimization #6064
Merged: LysandreJik merged 17 commits into huggingface:master from guillaume-be:banned_tokens_optimization on Aug 11, 2020
6ff6548 Optimized banned token masking (guillaume-be)
3f046f6 Avoid duplicate EOS masking if in bad_words_id (guillaume-be)
c717232 Updated mask generation to handle empty banned token list (guillaume-be)
f0adddb Addition of unit tests for the updated bad_words_ids masking (guillaume-be)
f4dabda Updated timeout handling in `test_postprocess_next_token_scores_large… (guillaume-be)
e1e0b55 Updated timeout handling in `test_postprocess_next_token_scores_large… (guillaume-be)
7d3686b Moving Marian import to the test context to allow TF only environment… (guillaume-be)
bd8e908 Moving imports to torch_available test (guillaume-be)
f482b63 Merge remote-tracking branch 'remotes/upstream/master' into banned_to… (guillaume-be)
d8a9368 Updated operations device and test (guillaume-be)
3e0e8fb Updated operations device and test (guillaume-be)
4f74e2a Added docstring and comment for in-place scores modification (guillaume-be)
4a2a4e6 Moving test to own test_generation_utils, use of lighter models for t… (guillaume-be)
568fbb9 removed unneded imports in test_modeling_common (guillaume-be)
2618221 revert formatting change for ModelTesterMixin (guillaume-be)
01c0697 Updated caching, simplified eos token id test, removed unnecessary @r… (guillaume-be)
7d9767a formatting compliance (guillaume-be)
@@ -0,0 +1,90 @@
import random
import unittest

import timeout_decorator

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch

    from transformers import (
        MarianConfig,
        MarianMTModel,
    )


@require_torch
class GenerationUtilsTest(unittest.TestCase):
    @cached_property
    def config(self):
        config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
        return config

    @cached_property
    def model(self):
        return MarianMTModel(self.config)

    def test_postprocess_next_token_scores(self):
        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))
        eos = config.eos_token_id
        bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
        masked_scores = [
            [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
            [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
            [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
            [],
        ]

        for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
            # Initialize a scores tensor with batch size 8 and vocabulary size 300
            scores = torch.rand((8, 300))
            output = model.postprocess_next_token_scores(
                scores,
                input_ids,
                0,
                bad_words_ids,
                13,
                15,
                config.max_length,
                config.eos_token_id,
                config.repetition_penalty,
                32,
                5,
            )
            for masked_score in masked_scores[test_case_index]:
                self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))

    @timeout_decorator.timeout(10)
    def test_postprocess_next_token_scores_large_bad_words_list(self):

        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))

        bad_words_ids = []
        for _ in range(100):
            length_bad_word = random.randint(1, 4)
            bad_words_ids.append(random.sample(range(1, 300), length_bad_word))

        scores = torch.rand((8, 300))
        _ = model.postprocess_next_token_scores(
            scores,
            input_ids,
            0,
            bad_words_ids,
            13,
            15,
            config.max_length,
            config.eos_token_id,
            config.repetition_penalty,
            32,
            5,
        )
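The expected `masked_scores` in `test_postprocess_next_token_scores` can be read as follows: a single-token bad word such as `[54]` is banned in every row, while a multi-token bad word such as `[23, 24]` only bans its last token in rows whose previous tokens end with the remaining prefix (here row 1, which holds ids 12..23). The pure-Python sketch below reproduces those positions under that reading. It is an illustration only, not the code changed in this PR: `banned_positions` is a hypothetical helper, and it deliberately ignores every other kind of score post-processing.

```python
from typing import List, Tuple


def banned_positions(
    input_ids: List[List[int]], bad_words_ids: List[List[int]]
) -> List[Tuple[int, int]]:
    """Return (batch_index, token_id) pairs whose scores would be masked.

    Hypothetical helper for illustration; not part of transformers.
    """
    positions = []
    for batch_index, previous_tokens in enumerate(input_ids):
        for bad_word in bad_words_ids:
            prefix, last_token = bad_word[:-1], bad_word[-1]
            # A prefix longer than the sequence so far can never match.
            if len(prefix) > len(previous_tokens):
                continue
            # An empty prefix (single-token bad word) always matches;
            # otherwise the prefix must be the tail of the previous tokens.
            if not prefix or previous_tokens[-len(prefix):] == prefix:
                positions.append((batch_index, last_token))
    return positions


# Same layout as the test: 8 rows of 12 consecutive ids (0..95).
input_ids = [list(range(row * 12, (row + 1) * 12)) for row in range(8)]
print(banned_positions(input_ids, [[23, 24], [54]]))
```

The returned pairs are the positions the test expects to equal `-float("inf")` for the `[[23, 24], [54]]` case; an empty `bad_words_ids` list yields no positions, matching the last test case.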