-
Notifications
You must be signed in to change notification settings - Fork 31.5k
[tests] update test_left_padding_compatibility (and minimize overwrites)
#40980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -438,88 +438,11 @@ def test_batching_equivalence(self): | |
| super().test_batching_equivalence() | ||
| self.model_tester.use_input_mask = orig | ||
|
|
||
| # essentially the same test in test_utils, just adjustment for rtol for this model | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this overwrite comment wasn't even true 😢 |
||
| @pytest.mark.generate | ||
| def test_left_padding_compatibility(self): | ||
| # NOTE: left-padding results in small numerical differences. This is expected. | ||
| # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 | ||
|
|
||
| # First, filter out models that don't support left padding | ||
| # - The model must have generative capabilities | ||
| if len(self.all_generative_model_classes) == 0: | ||
| self.skipTest(reason="No generative architecture available for this model.") | ||
|
|
||
| # - The model must support padding | ||
| if not self.has_attentions: | ||
| self.skipTest(reason="This model doesn't support padding.") | ||
|
|
||
| # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) | ||
| decoder_only_classes = [] | ||
| for model_class in self.all_generative_model_classes: | ||
| config, _ = self.prepare_config_and_inputs_for_generate() | ||
| if config.is_encoder_decoder: | ||
| continue | ||
| else: | ||
| decoder_only_classes.append(model_class) | ||
| if len(decoder_only_classes) == 0: | ||
| self.skipTest(reason="No decoder-only architecture available for this model.") | ||
|
|
||
| # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't | ||
| # added support for it yet. We skip these models for now. | ||
| has_encoder_attributes = any( | ||
| attr_name | ||
| for attr_name in config.to_dict() | ||
| if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" | ||
| ) | ||
| if has_encoder_attributes: | ||
| self.skipTest( | ||
| reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." | ||
| ) | ||
|
|
||
| # Then, test left-padding | ||
| def _prepare_model_kwargs(input_ids, attention_mask, signature): | ||
| model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} | ||
| if "position_ids" in signature: | ||
| position_ids = torch.cumsum(attention_mask, dim=-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| model_kwargs["position_ids"] = position_ids | ||
| if "cache_position" in signature: | ||
| cache_position = torch.arange(input_ids.shape[-1], device=torch_device) | ||
| model_kwargs["cache_position"] = cache_position | ||
| return model_kwargs | ||
|
|
||
| for model_class in decoder_only_classes: | ||
| config, inputs_dict = self.prepare_config_and_inputs_for_generate() | ||
| input_ids = inputs_dict["input_ids"] | ||
|
|
||
| # - for left padding we absolutely need to use an all ones | ||
| # attention mask, so we do not use the one in inputs_dict | ||
| attention_mask = torch.ones_like(input_ids) | ||
|
|
||
| model = model_class(config).to(torch_device).eval() | ||
| signature = inspect.signature(model.forward).parameters.keys() | ||
|
|
||
| # no cache as some models require special cache classes to be init outside forward | ||
| model.generation_config.use_cache = False | ||
|
|
||
| # Without padding | ||
| model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) | ||
| next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] | ||
|
|
||
| # With left-padding (length 32) | ||
| # can hardcode pad_token to be 0 as we'll do attn masking anyway | ||
| pad_token_id = ( | ||
| config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 | ||
| ) | ||
| pad_size = (input_ids.shape[0], 32) | ||
| padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id | ||
| padded_input_ids = torch.cat((padding, input_ids), dim=1) | ||
| padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) | ||
| model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) | ||
| next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] | ||
|
|
||
| # They should result in very similar logits | ||
| torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) | ||
| # TODO: document why a random attention mask causes this test to fail, but a full mask doesn't | ||
| unpadded_custom_inputs = {"attention_mask": None} | ||
| super().test_left_padding_compatibility(unpadded_custom_inputs=unpadded_custom_inputs) | ||
|
|
||
| @unittest.skip( | ||
| "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training." | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TL;DR
token_type_ids