Skip to content

Commit

Permalink
Generation: strict generation config validation at save time (#25411)
Browse files Browse the repository at this point in the history
* strict gen config save; Add tests

* add note that the warning will be an exception in v4.34
  • Loading branch information
gante authored Aug 10, 2023
1 parent 16edf4d commit 123ad53
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ def validate(self, is_init=False):
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature != 1.0:
Expand Down Expand Up @@ -392,8 +392,8 @@ def validate(self, is_init=False):
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
)
if self.early_stopping is not False:
warnings.warn(
Expand Down Expand Up @@ -430,9 +430,9 @@ def validate(self, is_init=False):
# constrained beam search
if self.constraints is not None:
constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue." + fix_location
"`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set "
"to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or "
"unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
Expand Down Expand Up @@ -497,6 +497,22 @@ def save_pretrained(
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""

# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance
try:
with warnings.catch_warnings(record=True) as caught_warnings:
self.validate()
for w in caught_warnings:
raise ValueError(w.message)
except ValueError as exc:
warnings.warn(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
"\n\nThrown during validation:\n" + str(exc),
UserWarning,
)
return

use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
Expand Down
35 changes: 35 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License.

import copy
import os
import tempfile
import unittest
import warnings

from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized
Expand Down Expand Up @@ -118,6 +120,39 @@ def test_kwarg_init(self):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value

def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""

# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a warning
config = GenerationConfig()
config.temperature = 0.5
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)

# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
# caught, doesn't save, and raises a warning
config = GenerationConfig()
config.num_return_sequences = 2
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)

# final check: no warnings thrown if it is correct, and file is saved
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
Expand Down

0 comments on commit 123ad53

Please sign in to comment.