From 123ad5363f1a24bdf8ff97e10200f05144cdebf6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 10 Aug 2023 10:42:34 +0100 Subject: [PATCH] Generation: strict generation config validation at save time (#25411) * strict gen config save; Add tests * add note that the warning will be an exception in v4.34 --- .../generation/configuration_utils.py | 30 ++++++++++++---- tests/generation/test_configuration_utils.py | 35 +++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 51644d9a6f0c93..ef0963f675d020 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -354,8 +354,8 @@ def validate(self, is_init=False): # 1. detect sampling-only parameterization when not in sampling mode if self.do_sample is False: greedy_wrong_parameter_msg = ( - "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used " - "in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}." + "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only " + "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`." + fix_location ) if self.temperature != 1.0: @@ -392,8 +392,8 @@ def validate(self, is_init=False): # 2. detect beam-only parameterization when not in beam mode if self.num_beams == 1: single_beam_wrong_parameter_msg = ( - "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in " - "beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location + "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used " + "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location ) if self.early_stopping is not False: warnings.warn( @@ -430,9 +430,9 @@ def validate(self, is_init=False): # constrained beam search if self.constraints is not None: constrained_wrong_parameter_msg = ( - "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to " - "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset " - "{flag_name} to continue." + fix_location + "`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set " + "to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or " + "unset `{flag_name}` to continue." + fix_location ) if self.do_sample is True: raise ValueError( @@ -497,6 +497,22 @@ def save_pretrained( kwargs (`Dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ + + # At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance + try: + with warnings.catch_warnings(record=True) as caught_warnings: + self.validate() + for w in caught_warnings: + raise ValueError(w.message) + except ValueError as exc: + warnings.warn( + "The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. " + "Fix these issues to save the configuration. This warning will be raised to an exception in v4.34." + "\n\nThrown during validation:\n" + str(exc), + UserWarning, + ) + return + use_auth_token = kwargs.pop("use_auth_token", None) if use_auth_token is not None: diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index f58b227c14598e..a181b00ee08d2c 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -14,8 +14,10 @@ # limitations under the License. import copy +import os import tempfile import unittest +import warnings from huggingface_hub import HfFolder, delete_repo from parameterized import parameterized @@ -118,6 +120,39 @@ def test_kwarg_init(self): self.assertEqual(loaded_config.do_sample, True) self.assertEqual(loaded_config.num_beams, 1) # default value + def test_refuse_to_save(self): + """Tests that we refuse to save a generation config that fails validation.""" + + # setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that + # is caught, doesn't save, and raises a warning + config = GenerationConfig() + config.temperature = 0.5 + with tempfile.TemporaryDirectory() as tmp_dir: + with warnings.catch_warnings(record=True) as captured_warnings: + config.save_pretrained(tmp_dir) + self.assertEqual(len(captured_warnings), 1) + self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message)) + self.assertTrue(len(os.listdir(tmp_dir)) == 0) + + # greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is + # caught, doesn't save, and raises a warning + config = GenerationConfig() + config.num_return_sequences = 2 + with tempfile.TemporaryDirectory() as tmp_dir: + with warnings.catch_warnings(record=True) as captured_warnings: + config.save_pretrained(tmp_dir) + self.assertEqual(len(captured_warnings), 1) + self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message)) + self.assertTrue(len(os.listdir(tmp_dir)) == 0) + + # final check: no warnings thrown if it is correct, and file is saved + config = GenerationConfig() + with tempfile.TemporaryDirectory() as tmp_dir: + with warnings.catch_warnings(record=True) as captured_warnings: + config.save_pretrained(tmp_dir) + self.assertEqual(len(captured_warnings), 0) + self.assertTrue(len(os.listdir(tmp_dir)) == 1) + @is_staging_test class ConfigPushToHubTester(unittest.TestCase):