
Refactor CI: more explicit #30674

Merged
merged 207 commits on Aug 30, 2024
Changes from 115 commits
Commits
207 commits
806500d
don't run custom when not needed?
ArthurZucker May 6, 2024
f24c92f
update test fetcher filtering
ArthurZucker May 6, 2024
2286e70
fixup and updates
ArthurZucker May 7, 2024
1a7ce2b
update
ArthurZucker May 7, 2024
8cdf454
update
ArthurZucker May 7, 2024
cfd83d4
reduce burden
ArthurZucker May 7, 2024
2fe4894
nit
ArthurZucker May 7, 2024
6a1319c
nit
ArthurZucker May 7, 2024
4a0d81d
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker May 7, 2024
d142602
mising comma
ArthurZucker May 7, 2024
2c67087
this?
ArthurZucker May 7, 2024
253cc17
this?
ArthurZucker May 7, 2024
34217bd
more parallelism
ArthurZucker May 7, 2024
d57d6c8
more
ArthurZucker May 7, 2024
2115e53
nit for real parallelism on tf and torch examples
ArthurZucker May 7, 2024
691095a
update
ArthurZucker May 7, 2024
74e915d
update
ArthurZucker May 7, 2024
6f13afd
update
ArthurZucker May 7, 2024
027ddae
update
ArthurZucker May 7, 2024
c6612ce
update
ArthurZucker May 7, 2024
5776118
update
ArthurZucker May 7, 2024
2f2af2c
update
ArthurZucker May 7, 2024
e2a7140
update
ArthurZucker May 7, 2024
285efdd
update
ArthurZucker May 7, 2024
ef09327
update
ArthurZucker May 7, 2024
2578e56
update
ArthurZucker May 7, 2024
9b16746
update
ArthurZucker May 7, 2024
86e5f2d
update to make it more custom
ArthurZucker May 7, 2024
03e3064
update to make it more custom
ArthurZucker May 7, 2024
1670d48
update to make it more custom
ArthurZucker May 7, 2024
f1c18bf
update to make it more custom
ArthurZucker May 7, 2024
7b35a6f
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker May 7, 2024
ababb92
update
ArthurZucker May 7, 2024
1fcbdc2
update
ArthurZucker May 7, 2024
26667e7
update
ArthurZucker May 7, 2024
428aa49
update
ArthurZucker May 7, 2024
8423ae3
update
ArthurZucker May 7, 2024
481250a
update
ArthurZucker May 7, 2024
7cf9073
use correct path
ArthurZucker May 7, 2024
80ba28d
fix path to test files and examples
ArthurZucker May 7, 2024
f7386d3
filter-tests
ArthurZucker May 7, 2024
a4cdd9b
filter?
ArthurZucker May 7, 2024
a10e514
filter?
ArthurZucker May 7, 2024
c8b96e3
filter?
ArthurZucker May 7, 2024
9ac28cb
nits
ArthurZucker May 7, 2024
31373d4
fix naming of the artifacts to be pushed
ArthurZucker May 7, 2024
3f152ba
list vs files
ArthurZucker May 7, 2024
91f1acf
list vs files
ArthurZucker May 7, 2024
09dd0f8
fixup
ArthurZucker May 7, 2024
3b53d12
fix list of all tests
ArthurZucker May 7, 2024
50e431b
fix the install steps
ArthurZucker May 7, 2024
09fef76
fix the install steps
ArthurZucker May 7, 2024
765887d
fix the config
ArthurZucker May 7, 2024
63a36ac
fix the config
ArthurZucker May 7, 2024
fcc8d2e
only split if needed
ArthurZucker May 7, 2024
00337b5
only split if needed
ArthurZucker May 7, 2024
4fcefd7
extend should fix it
ArthurZucker May 7, 2024
1f5e218
extend should fix it
ArthurZucker May 7, 2024
1b7d273
arg
ArthurZucker May 7, 2024
e7ca4bc
arg
ArthurZucker May 7, 2024
19a4796
update
ArthurZucker May 7, 2024
e719ffb
update
ArthurZucker May 7, 2024
7650147
run tests
ArthurZucker May 7, 2024
7ad1e9a
run tests
ArthurZucker May 7, 2024
d63a356
run tests
ArthurZucker May 7, 2024
7e28c5b
more nits
ArthurZucker May 7, 2024
32f4f1d
update
ArthurZucker May 7, 2024
fe9c153
update
ArthurZucker May 7, 2024
0ae7334
update
ArthurZucker May 7, 2024
27f17a2
update
ArthurZucker May 7, 2024
36c690e
update
ArthurZucker May 8, 2024
bedc625
update
ArthurZucker May 8, 2024
f0714ff
update
ArthurZucker May 8, 2024
b52eadb
simpler way to show the test, reduces the complexity of the generated…
ArthurZucker May 8, 2024
bc7a843
simpler way to show the test, reduces the complexity of the generated…
ArthurZucker May 8, 2024
d582572
style
ArthurZucker May 8, 2024
8bb929e
oups
ArthurZucker May 8, 2024
bbbae00
oups
ArthurZucker May 8, 2024
2619e19
fix import errors
ArthurZucker May 8, 2024
7e53974
skip some tests for now
ArthurZucker May 8, 2024
a01cdd7
update doctestjob
ArthurZucker May 8, 2024
eda4f6b
more parallelism
ArthurZucker May 8, 2024
669aeaa
fixup
ArthurZucker May 8, 2024
a0fcdd0
test only the test in examples
ArthurZucker May 8, 2024
2d617b9
test only the test in examples
ArthurZucker May 8, 2024
8307b91
nits
ArthurZucker May 8, 2024
a3ffbd3
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker Jun 4, 2024
cf68e36
from Arthur
ydshieh Jun 6, 2024
4f5c896
Merge branches 'improve_fetcher' and 'main' of github.com:huggingface…
ArthurZucker Aug 8, 2024
eb3cf68
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker Aug 23, 2024
f7f4dfc
fix generated congi
ArthurZucker Aug 23, 2024
4608b9a
update
ArthurZucker Aug 23, 2024
de36383
update
ArthurZucker Aug 23, 2024
3600c7e
show tests
ArthurZucker Aug 23, 2024
ed083c1
oups
ArthurZucker Aug 23, 2024
79fb360
oups
ArthurZucker Aug 23, 2024
fc3617a
fix torch job for now
ArthurZucker Aug 23, 2024
c96eaa5
use single upload setp
ArthurZucker Aug 23, 2024
813ffcd
oups
ArthurZucker Aug 23, 2024
5af8ade
fu**k
ArthurZucker Aug 23, 2024
02b1b55
fix
ArthurZucker Aug 23, 2024
f00d715
nit
ArthurZucker Aug 23, 2024
4867ca5
update
ArthurZucker Aug 23, 2024
fa95aab
nit
ArthurZucker Aug 23, 2024
d40008c
fix
ArthurZucker Aug 23, 2024
f5a1543
fixes
ArthurZucker Aug 23, 2024
b9ac232
[test-all]
ArthurZucker Aug 23, 2024
73434f3
add generate marker and generate job
ArthurZucker Aug 23, 2024
aa31101
oups
ArthurZucker Aug 23, 2024
57c6cd0
torch job runs not generate tests
ArthurZucker Aug 23, 2024
7487359
let repo utils test all utils
ArthurZucker Aug 23, 2024
9df5097
UPdate
ArthurZucker Aug 23, 2024
c0a6c77
styling
ArthurZucker Aug 23, 2024
8ea7a67
fix repo utils test
ArthurZucker Aug 23, 2024
e683c37
more parallel please
ArthurZucker Aug 23, 2024
a606c08
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker Aug 29, 2024
0c2571b
don't test
ArthurZucker Aug 29, 2024
bdb29a3
update
ArthurZucker Aug 29, 2024
988df82
bit more verbose sir
ArthurZucker Aug 29, 2024
07a76c8
more
ArthurZucker Aug 29, 2024
8453ac9
hub were skipped
ArthurZucker Aug 29, 2024
d2fd4a6
split by classname
ArthurZucker Aug 29, 2024
e4266be
revert
ArthurZucker Aug 29, 2024
6a1be60
maybe?
ArthurZucker Aug 29, 2024
caaadef
Amazing catch
ArthurZucker Aug 29, 2024
9180e62
fix
ArthurZucker Aug 29, 2024
6d7f6bd
update
ArthurZucker Aug 29, 2024
0ba0463
update
ArthurZucker Aug 29, 2024
048266d
maybe non capturing
ArthurZucker Aug 29, 2024
3dffd4c
manual convert?
ArthurZucker Aug 29, 2024
fef117f
pass artifacts as parameters as otherwise the config is too long
ArthurZucker Aug 30, 2024
a9ca273
artifact.json
ArthurZucker Aug 30, 2024
6f78219
store output
ArthurZucker Aug 30, 2024
88fe213
might not be safe?
ArthurZucker Aug 30, 2024
4997483
my token
ArthurZucker Aug 30, 2024
9e989e2
mmm?
ArthurZucker Aug 30, 2024
17e7538
use CI job IS
ArthurZucker Aug 30, 2024
f27df25
can't get a proper id?
ArthurZucker Aug 30, 2024
b5ed61f
ups
ArthurZucker Aug 30, 2024
71641a8
build num
ArthurZucker Aug 30, 2024
852af8b
update
ArthurZucker Aug 30, 2024
94db0a1
echo url
ArthurZucker Aug 30, 2024
ecd3885
this?
ArthurZucker Aug 30, 2024
76fa821
this!
ArthurZucker Aug 30, 2024
16ae707
fix
ArthurZucker Aug 30, 2024
f03e231
wget
ArthurZucker Aug 30, 2024
65758e2
ish
ArthurZucker Aug 30, 2024
312b8da
dang
ArthurZucker Aug 30, 2024
4366a96
udpdate
ArthurZucker Aug 30, 2024
6d791db
there we go
ArthurZucker Aug 30, 2024
b8071c6
update
ArthurZucker Aug 30, 2024
1c70bce
update
ArthurZucker Aug 30, 2024
021d458
pass all
ArthurZucker Aug 30, 2024
0e1ee90
not .txt
ArthurZucker Aug 30, 2024
8e4ed17
update
ArthurZucker Aug 30, 2024
a438e50
fetcg
ArthurZucker Aug 30, 2024
908d0c6
fix naming
ArthurZucker Aug 30, 2024
a5c9ba3
fix
ArthurZucker Aug 30, 2024
346fce8
up
ArthurZucker Aug 30, 2024
2ec7266
update
ArthurZucker Aug 30, 2024
4c135dc
update
ArthurZucker Aug 30, 2024
6a8fee2
??
ArthurZucker Aug 30, 2024
d3048b5
update
ArthurZucker Aug 30, 2024
ffa46fa
more updates
ArthurZucker Aug 30, 2024
b9e2881
update
ArthurZucker Aug 30, 2024
f36201c
more
ArthurZucker Aug 30, 2024
8bb4eb4
skip
ArthurZucker Aug 30, 2024
5ed59ea
oups
ArthurZucker Aug 30, 2024
1980be4
pr documentation tests are currently created differently
ArthurZucker Aug 30, 2024
785fc50
update
ArthurZucker Aug 30, 2024
862e02b
hmmmm
ArthurZucker Aug 30, 2024
44fc481
oups
ArthurZucker Aug 30, 2024
9f23e5d
curl -L
ArthurZucker Aug 30, 2024
b81c5d5
update
ArthurZucker Aug 30, 2024
36f40cb
????
ArthurZucker Aug 30, 2024
5924212
nit
ArthurZucker Aug 30, 2024
f3b1175
mmmm
ArthurZucker Aug 30, 2024
1685930
ish
ArthurZucker Aug 30, 2024
c4cffb2
ouf
ArthurZucker Aug 30, 2024
c9c7206
update
ArthurZucker Aug 30, 2024
397a8da
ish
ArthurZucker Aug 30, 2024
dc7ba2f
update
ArthurZucker Aug 30, 2024
e0cf368
update
ArthurZucker Aug 30, 2024
ef47546
updatea
ArthurZucker Aug 30, 2024
8dd2877
nit
ArthurZucker Aug 30, 2024
29351e8
nit
ArthurZucker Aug 30, 2024
7566840
up
ArthurZucker Aug 30, 2024
4633c3c
oups
ArthurZucker Aug 30, 2024
4586d19
documentation_test fix
ArthurZucker Aug 30, 2024
5403588
test hub tests everything, just marker
ArthurZucker Aug 30, 2024
c507956
Merge branch 'main' of github.com:huggingface/transformers into impro…
ArthurZucker Aug 30, 2024
cfa8c8c
update
ArthurZucker Aug 30, 2024
36f7aa2
fix
ArthurZucker Aug 30, 2024
b832165
test_hub is the only annoying one now
ArthurZucker Aug 30, 2024
61d33d9
tf threads?
ArthurZucker Aug 30, 2024
5013546
oups
ArthurZucker Aug 30, 2024
10f2bd1
not sure what is happening?
ArthurZucker Aug 30, 2024
202d9db
fix?
ArthurZucker Aug 30, 2024
eea1314
just use folder for stating hub
ArthurZucker Aug 30, 2024
a6f7edd
I am getting fucking annoyed
ArthurZucker Aug 30, 2024
da3bafb
fix the test?
ArthurZucker Aug 30, 2024
c8251e8
update
ArthurZucker Aug 30, 2024
5e6a81a
uupdate
ArthurZucker Aug 30, 2024
86f5435
?
ArthurZucker Aug 30, 2024
c0dbe1b
fixes
ArthurZucker Aug 30, 2024
9e9d3ef
add comment!
ArthurZucker Aug 30, 2024
8c99bac
nit
ArthurZucker Aug 30, 2024
46 changes: 1 addition & 45 deletions .circleci/config.yml
@@ -34,53 +34,9 @@ jobs:
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
- store_artifacts:
path: ~/transformers/tests_fetched_summary.txt
- run: |
if [ -f test_list.txt ]; then
cp test_list.txt test_preparation/test_list.txt
else
touch test_preparation/test_list.txt
fi
- run: |
if [ -f examples_test_list.txt ]; then
mv examples_test_list.txt test_preparation/examples_test_list.txt
else
touch test_preparation/examples_test_list.txt
fi
- run: |
if [ -f filtered_test_list_cross_tests.txt ]; then
mv filtered_test_list_cross_tests.txt test_preparation/filtered_test_list_cross_tests.txt
else
touch test_preparation/filtered_test_list_cross_tests.txt
fi
- run: |
if [ -f doctest_list.txt ]; then
cp doctest_list.txt test_preparation/doctest_list.txt
else
touch test_preparation/doctest_list.txt
fi
- run: |
if [ -f test_repo_utils.txt ]; then
mv test_repo_utils.txt test_preparation/test_repo_utils.txt
else
touch test_preparation/test_repo_utils.txt
fi
- run: python utils/tests_fetcher.py --filter_tests
- run: |
if [ -f test_list.txt ]; then
mv test_list.txt test_preparation/filtered_test_list.txt
else
touch test_preparation/filtered_test_list.txt
fi
- store_artifacts:
path: test_preparation/test_list.txt
- store_artifacts:
path: test_preparation/doctest_list.txt
- store_artifacts:
path: ~/transformers/test_preparation/filtered_test_list.txt
- store_artifacts:
path: test_preparation/examples_test_list.txt
path: test_preparation
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
if [ ! -s test_preparation/generated_config.yml ]; then
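The lines removed from `config.yml` above all repeat one pattern: move a test-list file into `test_preparation/` if it exists, otherwise create it empty, then upload each file as its own artifact. The diff collapses the uploads into a single `store_artifacts` on the folder. A minimal Python sketch of the consolidated move-or-touch step (function name and file names are illustrative, not from the PR):

```python
from pathlib import Path

def prepare_test_lists(names, dest="test_preparation"):
    """Move each list file into `dest` if present, else create it empty.

    Mirrors the repeated `if [ -f ... ]; then mv ...; else touch ...; fi`
    shell blocks that the diff above replaces with one folder upload.
    """
    dest_dir = Path(dest)
    dest_dir.mkdir(exist_ok=True)
    for name in names:
        src, target = Path(name), dest_dir / name
        if src.exists():
            src.rename(target)   # the `mv` branch
        else:
            target.touch()       # the `touch` branch

# usage sketch: one call instead of five near-identical shell blocks
prepare_test_lists(["test_list.txt", "doctest_list.txt", "examples_test_list.txt"])
```

Storing the whole `test_preparation` folder then needs only one `store_artifacts` step, regardless of how many list files the fetcher produces.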
381 changes: 105 additions & 276 deletions .circleci/create_circleci_config.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/pytorch/language-modeling/run_fim.py
@@ -47,10 +47,10 @@
Trainer,
TrainingArguments,
default_data_collator,
is_deepspeed_zero3_enabled,
is_torch_tpu_available,
set_seed,
)
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
2 changes: 1 addition & 1 deletion examples/pytorch/language-modeling/run_fim_no_trainer.py
@@ -52,9 +52,9 @@
SchedulerType,
default_data_collator,
get_scheduler,
is_deepspeed_zero3_enabled,
is_torch_tpu_available,
)
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,4 +35,5 @@ doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
]
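The `generate` marker registered above lets CI jobs select or deselect generation tests from the command line, e.g. `pytest -m generate` or `pytest -m "not generate"`. Pytest evaluates `-m` as a boolean expression over a test's marker names; the following is a simplified stdlib-only sketch of that selection logic, not pytest's actual implementation (names are illustrative):

```python
def matches(expr: str, markers: set) -> bool:
    """Evaluate a pytest-style `-m` expression (e.g. 'not generate',
    'generate and not flash_attn_test') against a test's marker names.

    Simplified sketch: map each marker name in the expression to
    True/False, then evaluate it as a boolean expression.
    """
    keywords = {"and", "or", "not", "(", ")"}
    names = {tok for tok in expr.split() if tok not in keywords}
    env = {name: (name in markers) for name in names}
    return bool(eval(expr, {"__builtins__": {}}, env))

# selecting tests the way `pytest -m "not generate"` would
tests = {
    "test_greedy_generate": {"generate"},
    "test_config_roundtrip": set(),
}
selected = [t for t, m in tests.items() if matches("not generate", m)]
```

This is why the torch job can skip all `GenerationTesterMixin` tests with a single `-m "not generate"` flag while a dedicated generate job runs them in parallel.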
37 changes: 37 additions & 0 deletions tests/generation/test_utils.py
@@ -21,6 +21,7 @@
import warnings

import numpy as np
import pytest
from parameterized import parameterized

from transformers import is_torch_available, pipeline, set_seed
@@ -88,6 +89,7 @@
from transformers.generation.utils import _speculative_sampling


@pytest.mark.generate
class GenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
Expand Down Expand Up @@ -417,6 +419,7 @@ def _contrastive_generate(

return output_generate

@pytest.mark.generate
def test_greedy_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -429,6 +432,7 @@ def test_greedy_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -459,6 +463,7 @@ def test_greedy_generate_dict_outputs(self):

self._check_outputs(output_generate, input_ids, model.config)

@pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -488,6 +493,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)

@pytest.mark.generate
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -505,6 +511,7 @@ def test_sample_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

@pytest.mark.generate
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -536,6 +543,7 @@ def test_sample_generate_dict_output(self):

self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)

@pytest.mark.generate
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -555,6 +563,7 @@ def test_beam_search_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

@pytest.mark.generate
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -588,6 +597,7 @@ def test_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)

@pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
# enable cache
@@ -626,6 +636,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):

@require_accelerate
@require_torch_multi_accelerator
@pytest.mark.generate
def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
if "xpu" in torch_device:
@@ -648,6 +659,7 @@ def test_model_parallel_beam_search(self):
num_beams=2,
)

@pytest.mark.generate
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -684,6 +696,7 @@ def test_beam_sample_generate(self):

torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)

@pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -719,6 +732,7 @@ def test_beam_sample_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)

@pytest.mark.generate
def test_generate_without_input_ids(self):
config, _, _ = self._get_input_ids_and_config()

@@ -739,6 +753,7 @@ def test_generate_without_input_ids(self):
)
self.assertIsNotNone(output_ids_generate)

@pytest.mark.generate
def test_group_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -771,6 +786,7 @@ def test_group_beam_search_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

@pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -806,6 +822,7 @@ def test_group_beam_search_generate_dict_output(self):

# TODO: @gante
@is_flaky()
@pytest.mark.generate
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -863,6 +880,7 @@ def test_constrained_beam_search_generate(self):
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)

@pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -907,6 +925,7 @@ def test_constrained_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)

@pytest.mark.generate
def test_contrastive_generate(self):
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
@@ -933,6 +952,7 @@ def test_contrastive_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

@pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
@@ -968,6 +988,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)

@pytest.mark.generate
def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
@@ -1011,6 +1032,7 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@pytest.mark.generate
def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
@@ -1053,6 +1075,7 @@ def test_beam_search_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
@@ -1134,6 +1157,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@is_flaky()
@pytest.mark.generate
def test_prompt_lookup_decoding_matches_greedy_search(self):
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
@@ -1196,6 +1220,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@pytest.mark.generate
def test_dola_decoding_sample(self):
# TODO (joao): investigate skips, try to reduce incompatibilities
for model_class in self.all_generative_model_classes:
@@ -1240,6 +1265,7 @@ def test_dola_decoding_sample(self):
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))

@pytest.mark.generate
def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
@@ -1299,6 +1325,7 @@ def test_assisted_decoding_sample(self):

self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)

@pytest.mark.generate
def test_prompt_lookup_decoding_stops_at_eos(self):
# This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
# (see https://github.com/huggingface/transformers/pull/31301)
@@ -1327,6 +1354,7 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
# PLD shouldn't propose any new tokens based on eos-match
self.assertTrue(output_prompt_lookup.shape[-1] == 10)

@pytest.mark.generate
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1366,6 +1394,7 @@ def test_generate_with_head_masking(self):
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)

@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
@@ -1434,6 +1463,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))

@pytest.mark.generate
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
# standard KV cache format is important for a consistent API (and for advanced generation methods).
@@ -1505,6 +1535,7 @@ def test_past_key_values_format(self):
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)

@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
@@ -1555,6 +1586,7 @@ def test_generate_from_inputs_embeds_decoder_only(self):
outputs_from_embeds_wo_ids.tolist(),
)

@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
@@ -1632,6 +1664,7 @@ def test_generate_continue_from_past_key_values(self):
)

@parameterized.expand([(1, False), (1, True), (4, False)])
@pytest.mark.generate
def test_new_cache_format(self, num_beams, do_sample):
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
# 👉 tests with and without beam search so that we can test with and without cache reordering.
@@ -1696,6 +1729,7 @@ def test_new_cache_format(self, num_beams, do_sample):
)
)

@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
Tests if StaticCache works if we set attn_implementation=static when generation.
@@ -1744,6 +1778,7 @@ def test_generate_with_static_cache(self):
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)

@require_quanto
@pytest.mark.generate
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
@@ -1776,6 +1811,7 @@ def test_generate_with_quant_cache(self):
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

@pytest.mark.generate
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
@@ -2078,6 +2114,7 @@ def test_speculative_sampling(self):
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])


@pytest.mark.generate
@require_torch
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports