diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 54ce68c696187..2b76b36902977 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
 
+- Fixed an issue causing a `ValueError` for certain objects such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804))
+
 
 ## [2.2.2] - 2024-04-11
 
diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py
index f8e9c8300337a..521192f500b53 100644
--- a/src/lightning/pytorch/core/saving.py
+++ b/src/lightning/pytorch/core/saving.py
@@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us
         try:
             v = v.name if isinstance(v, Enum) else v
             yaml.dump(v)
-        except TypeError:
+        except (TypeError, ValueError):
             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
             hparams[k] = type(v).__name__
         else:
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index 0d7fced3b8197..e8a3cf680170b 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path):
     trainer.fit(model)
 
 
-def test_hparams_save_yaml(tmp_path):
+def test_save_hparams_to_yaml(tmp_path):
     class Options(str, Enum):
         option1name = "option1val"
         option2name = "option2val"
@@ -590,6 +590,14 @@ def _compare_params(loaded_params, default_params: dict):
     _compare_params(load_hparams_from_yaml(path_yaml), hparams)
 
 
+def test_save_hparams_to_yaml_warning(tmp_path):
+    """Test that we warn about unserializable parameters that need to be dropped."""
+    path_yaml = tmp_path / "hparams.yaml"
+    hparams = {"torch_type": torch.float32}
+    with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"):
+        save_hparams_to_yaml(path_yaml, hparams)
+
+
 class NoArgsSubClassBoringModel(CustomBoringModel):
     def __init__(self):
         super().__init__()
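
For context on the `saving.py` change: `save_hparams_to_yaml` probes each hyperparameter with a trial `yaml.dump` and falls back to storing the type name when the value cannot be serialized. Some objects, such as `torch.float32` or TorchMetrics instances, can make PyYAML raise a `ValueError` rather than a `TypeError`, which previously escaped the `except` clause and crashed the save. Below is a minimal standalone sketch of the widened guarded-dump pattern, assuming PyYAML and torch are installed; `_sanitize_hparams` is a hypothetical helper name used for illustration, not part of the Lightning API:

```python
from enum import Enum
from warnings import warn

import torch
import yaml


def _sanitize_hparams(hparams: dict) -> dict:
    """Replace values that cannot be safely dumped to YAML with their type name."""
    for k, v in hparams.items():
        try:
            v = v.name if isinstance(v, Enum) else v
            yaml.dump(v)  # trial dump; we only care whether it raises
        # Some objects (e.g. torch dtypes, TorchMetrics) raise ValueError
        # rather than TypeError, hence the widened except clause in the patch.
        except (TypeError, ValueError):
            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            hparams[k] = type(v).__name__
    return hparams


print(_sanitize_hparams({"lr": 0.1, "torch_type": torch.float32}))
# Expected: {'lr': 0.1, 'torch_type': 'dtype'}, plus a UserWarning for 'torch_type'
```

The trial dump keeps the resulting `hparams.yaml` loadable: anything PyYAML cannot round-trip is replaced by a plain string (its type name) and reported via a warning, instead of aborting the whole save.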