diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 3f7873a615af71..d1dd006ac7db1a 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,6 +5,6 @@ matplotlib>3.1, <3.7.3 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] >=4.18.0, <4.24.0 # strict +jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator rich >=12.3.0, <=13.5.2 tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 86fd7a9e7cd78e..20c9fb8a4cbc96 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -103,6 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added validation of user input for `devices` and `num_nodes` when running with `SLURM` or `TorchElastic` ([#18292](https://github.com/Lightning-AI/lightning/pull/18292)) +- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105)) + ### Changed diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 95f105402b608a..691982c1d0fb41 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -15,9 +15,10 @@ import sys from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union import torch +import yaml from lightning_utilities.core.imports import RequirementCache from lightning_utilities.core.rank_zero import _warn from torch.optim import Optimizer @@ -26,6 +27,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer +from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_warn @@ -196,6 +198,30 @@ def add_lr_scheduler_args( self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + def class_instantiator(self, class_type, *args, **kwargs): + for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items(): + if issubclass(class_type, base_type): + with given_hyperparameters_context(hparams): + return super().class_instantiator(class_type, *args, **kwargs) + return super().class_instantiator(class_type, *args, **kwargs) + + def instantiate_classes( + self, + cfg: Namespace, + instantiate_groups: bool = True, + hparam_context: Optional[Dict[str, type]] = None, + ) -> Namespace: + if hparam_context: + cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets! + self._hparam_context = {} + for key, base_type in hparam_context.items(): + hparams = cfg_dict.get(key, {}) + self._hparam_context[key] = (base_type, hparams) + init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups) + if hparam_context: + delattr(self, "_hparam_context") + return init + class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. @@ -530,7 +556,13 @@ def before_instantiate_classes(self) -> None: def instantiate_classes(self) -> None: """Instantiates the classes and sets their attributes.""" - self.config_init = self.parser.instantiate_classes(self.config) + hparam_prefix = "" + if "subcommand" in self.config: + hparam_prefix = self.config["subcommand"] + "." + hparam_context = {hparam_prefix + "model": self._model_class} + if self.datamodule_class is not None: + hparam_context[hparam_prefix + "data"] = self._datamodule_class + self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context) self.datamodule = self._get(self.config_init, "data") self.model = self._get(self.config_init, "model") self._add_configure_optimizers_method_to_model(self.subcommand) @@ -754,3 +786,17 @@ def _get_short_description(component: object) -> Optional[str]: return docstring.short_description except (ValueError, docstring_parser.ParseError) as ex: rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") + + +ModuleType = TypeVar("ModuleType") + + +def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: + parser = ArgumentParser(exit_on_error=False) + if "class_path" in config: + parser.add_subclass_arguments(class_type, "module") + else: + parser.add_class_arguments(class_type, "module") + cfg = parser.parse_object({"module": config}) + init = parser.instantiate_classes(cfg) + return init.module diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index ca6ad172e0725f..d5640b5b7684ad 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -15,6 +15,8 @@ import inspect import types from argparse import Namespace +from contextlib import contextmanager +from contextvars import ContextVar from typing import Any, List, MutableMapping, Optional, Sequence, Union from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters @@ -23,6 +25,18 @@ _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) +given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None) + + +@contextmanager +def given_hyperparameters_context(value): + token = given_hyperparameters.set(value) + try: + yield + finally: + given_hyperparameters.reset(token) + + class HyperparametersMixin: __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] @@ -103,12 +117,13 @@ class ``__init__`` to be ignored "arg3": 3.14 """ self._log_hyperparams = logger + given_hparams = given_hyperparameters.get() # the frame needs to be created in this file. - if not frame: + if given_hparams is None and not frame: current_frame = inspect.currentframe() if current_frame: frame = current_frame.f_back - save_hyperparameters(self, *args, ignore=ignore, frame=frame) + save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams) def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: hp = self._to_hparams_dict(hp) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 0321a61bb2863f..cde3f546e6f1d4 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -123,6 +123,7 @@ def _load_state( cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], strict: Optional[bool] = None, + instantiator=None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) @@ -160,7 +161,7 @@ def _load_state( # filter kwargs according to class init unless it allows any argument via kwargs _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} - obj = cls(**_cls_kwargs) + obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs) if isinstance(obj, pl.LightningModule): # give model a chance to load something diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 7ed77af18930af..41db289646ab21 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -138,7 +138,11 @@ def collect_init_args( def save_hyperparameters( - obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None + obj: Any, + *args: Any, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None, + given_hparams: Optional[Dict[str, Any]] = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -154,7 +158,9 @@ def save_hyperparameters( if not isinstance(frame, types.FrameType): raise AttributeError("There is no `frame` available while being required.") - if is_dataclass(obj): + if given_hparams is not None: + init_args = given_hparams + elif is_dataclass(obj): init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} else: init_args = {} diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index e21bb4b241b13a..a0df406d387bd9 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -40,6 +40,7 @@ from lightning.pytorch.cli import ( _JSONARGPARSE_SIGNATURES_AVAILABLE, instantiate_class, + instantiate_module, LightningArgumentParser, LightningCLI, LRSchedulerCallable, @@ -833,6 +834,53 @@ def configure_optimizers(self): assert init[1]["lr_scheduler"].gamma == 0.3 +def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): + class TestModel(BoringModel): + def __init__( + self, + optimizer: OptimizerCallable = torch.optim.Adam, + scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, + activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05), + ): + super().__init__() + self.save_hyperparameters() + self.optimizer = optimizer + self.scheduler = scheduler + self.activation = activation + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters()) + scheduler = self.scheduler(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]): + cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False) + cli.trainer.fit(cli.model) + + hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" + assert hparams_path.is_file() + hparams = yaml.safe_load(hparams_path.read_text()) + expected = { + "optimizer": "torch.optim.Adam", + "scheduler": "torch.optim.lr_scheduler.ConstantLR", + "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, + } + assert hparams == expected + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) + assert checkpoint_path.is_file() + ckpt = torch.load(checkpoint_path) + assert ckpt["hyper_parameters"] == expected + + model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module) + assert isinstance(model, TestModel) + assert isinstance(model.activation, torch.nn.LeakyReLU) + assert model.activation.negative_slope == 0.05 + optimizer, lr_scheduler = model.configure_optimizers().values() + assert isinstance(optimizer, torch.optim.Adam) + assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR) + + @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn]) def test_lightning_cli_trainer_fn(fn): class TestCLI(LightningCLI):