Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generation: strict generation config validation at save time #25411

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ def validate(self, is_init=False):
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature != 1.0:
Expand Down Expand Up @@ -392,8 +392,8 @@ def validate(self, is_init=False):
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
)
if self.early_stopping is not False:
warnings.warn(
Expand Down Expand Up @@ -430,9 +430,9 @@ def validate(self, is_init=False):
# constrained beam search
if self.constraints is not None:
constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue." + fix_location
"`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set "
"to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or "
"unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
Expand Down Expand Up @@ -497,6 +497,21 @@ def save_pretrained(
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""

# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance
try:
with warnings.catch_warnings(record=True) as caught_warnings:
self.validate()
for w in caught_warnings:
raise ValueError(w.message)
except ValueError as exc:
warnings.warn(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc),
UserWarning,
)
return
Copy link
Member Author

@gante gante Aug 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decided to throw a warning here (and not an exception) so this doesn't break long scripts like training -- calling a model save_pretrained also calls this function

Note that some old checks were already exceptions, the code here aims at moving them to the same level (exception) and treat them in a non-blocking way (warning)


use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
Expand Down
35 changes: 35 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License.

import copy
import os
import tempfile
import unittest
import warnings

from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized
Expand Down Expand Up @@ -118,6 +120,39 @@ def test_kwarg_init(self):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value

def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""

# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a warning
config = GenerationConfig()
config.temperature = 0.5
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)

# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
# caught, doesn't save, and raises a warning
config = GenerationConfig()
config.num_return_sequences = 2
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)

# final check: no warnings thrown if it is correct, and file is saved
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
Expand Down
Loading