-
Notifications
You must be signed in to change notification settings - Fork 31.3k
[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫
#36180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fc5ef0c
4ee061e
286270a
fc63465
92d65cf
d2c2859
34d4a1e
e971049
e90c389
8c4f298
129a5a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| import datetime | ||
| import gc | ||
| import inspect | ||
| import random | ||
| import tempfile | ||
| import unittest | ||
| import warnings | ||
|
|
@@ -48,8 +49,6 @@ | |
| ) | ||
| from transformers.utils import is_ipex_available | ||
|
|
||
| from ..test_modeling_common import floats_tensor, ids_tensor | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| import torch | ||
|
|
@@ -2786,6 +2785,43 @@ def test_speculative_sampling_target_distribution(self): | |
| self.assertTrue(last_token_counts[8] > last_token_counts[3]) | ||
|
|
||
|
|
||
| global_rng = random.Random() | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copied from
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a comment with "copied from" can be added i think
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are these used in a lot of places in this file, or just inside one method? If so, we can probably avoid circular dependencies by importing them within that (single) method ..?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good idea, moving to an internal import to prevent code bloat
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. uhmmm local imports would be needed in many places, will go with
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one another possible approach is not to use
but check the Up to you. p.s.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's also okay to have a pure copy of the short functions :P It's just a handful of lines, I don't think it's worth the extra work for now -- I will have to refactor these lines when we remove TF (i.e. very soon) 👀
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK 👍 |
||
|
|
||
|
|
||
| # Copied from tests.test_modeling_common.ids_tensor | ||
| def ids_tensor(shape, vocab_size, rng=None, name=None): | ||
| # Creates a random int32 tensor of the shape within the vocab size | ||
| if rng is None: | ||
| rng = global_rng | ||
|
|
||
| total_dims = 1 | ||
| for dim in shape: | ||
| total_dims *= dim | ||
|
|
||
| values = [] | ||
| for _ in range(total_dims): | ||
| values.append(rng.randint(0, vocab_size - 1)) | ||
|
|
||
| return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() | ||
|
|
||
|
|
||
| # Copied from tests.test_modeling_common.floats_tensor | ||
| def floats_tensor(shape, scale=1.0, rng=None, name=None): | ||
| """Creates a random float32 tensor""" | ||
| if rng is None: | ||
| rng = global_rng | ||
|
|
||
| total_dims = 1 | ||
| for dim in shape: | ||
| total_dims *= dim | ||
|
|
||
| values = [] | ||
| for _ in range(total_dims): | ||
| values.append(rng.random() * scale) | ||
|
|
||
| return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous() | ||
|
|
||
|
|
||
| @pytest.mark.generate | ||
| @require_torch | ||
| class GenerationIntegrationTests(unittest.TestCase): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -451,6 +451,8 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) | |
| if is_torch_available() | ||
| else () | ||
| ) | ||
| # Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante | ||
| all_generative_model_classes = () | ||
|
Comment on lines
+454
to
+455
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for my understanding: do we need to have empty
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If a model inherits
option 2 is intentionally annoying (we are forced to overwrite a property), so we are very explicit about skipping tests. We don't want skips to happen unless we're very intentional about it. |
||
| pipeline_model_mapping = ( | ||
| { | ||
| "feature-extraction": BigBirdModel, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can_generate()is only used inGenerationMixin-related code. Let's remove time series model from this function.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it completely different or uses part of
generate(), like some audio models?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's completely different ☠️