load_from_checkpoint support for LightningCLI when using dependency injection.
mauvilsa committed Aug 22, 2023
1 parent fbdbe63 commit 8388f88
Showing 7 changed files with 126 additions and 8 deletions.
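In practice the change enables the pattern exercised by the new test further down: a model whose constructor receives injected callables can be trained through `LightningCLI` and later rebuilt from its checkpoint by passing the new `instantiate_module` helper as the `instantiator`. A minimal sketch of that flow, assuming a BoringModel-based model and the default checkpoint location (names here are illustrative, not part of the commit):

```python
from pathlib import Path

import torch
from lightning.pytorch.cli import LightningCLI, OptimizerCallable, instantiate_module
from lightning.pytorch.demos.boring_classes import BoringModel


class MyModel(BoringModel):  # illustrative model using dependency injection
    def __init__(self, optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()
        self.save_hyperparameters()  # the CLI supplies the parsed config via a context variable
        self.optimizer = optimizer

    def configure_optimizers(self):
        return self.optimizer(self.parameters())


cli = LightningCLI(MyModel, run=False, auto_configure_optimizers=False, args=["--trainer.max_epochs=1"])
cli.trainer.fit(cli.model)

# Rebuild the model from the checkpoint; injected callables stored in the
# hyperparameters as class paths are resolved into objects again.
ckpt_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
model = MyModel.load_from_checkpoint(ckpt_path, instantiator=instantiate_module)
```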
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,6 +5,6 @@
matplotlib>3.1, <3.7.3
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
jsonargparse[signatures] >=4.18.0, <4.24.0 # strict
jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator
rich >=12.3.0, <=13.5.2
tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -103,6 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added validation of user input for `devices` and `num_nodes` when running with `SLURM` or `TorchElastic` ([#18292](https://github.com/Lightning-AI/lightning/pull/18292))


- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))


### Changed

50 changes: 48 additions & 2 deletions src/lightning/pytorch/cli.py
@@ -15,9 +15,10 @@
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union

import torch
import yaml
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer
@@ -26,6 +27,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
@@ -196,6 +198,30 @@ def add_lr_scheduler_args(
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

def class_instantiator(self, class_type, *args, **kwargs):
for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items():
if issubclass(class_type, base_type):
with given_hyperparameters_context(hparams):
return super().class_instantiator(class_type, *args, **kwargs)
return super().class_instantiator(class_type, *args, **kwargs)

def instantiate_classes(
self,
cfg: Namespace,
instantiate_groups: bool = True,
hparam_context: Optional[Dict[str, type]] = None,
) -> Namespace:
if hparam_context:
cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets!
self._hparam_context = {}
for key, base_type in hparam_context.items():
hparams = cfg_dict.get(key, {})
self._hparam_context[key] = (base_type, hparams)
init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups)
if hparam_context:
delattr(self, "_hparam_context")
return init


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -530,7 +556,13 @@ def before_instantiate_classes(self) -> None:

def instantiate_classes(self) -> None:
"""Instantiates the classes and sets their attributes."""
self.config_init = self.parser.instantiate_classes(self.config)
hparam_prefix = ""
if "subcommand" in self.config:
hparam_prefix = self.config["subcommand"] + "."
hparam_context = {hparam_prefix + "model": self._model_class}
if self.datamodule_class is not None:
hparam_context[hparam_prefix + "data"] = self._datamodule_class
self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context)
self.datamodule = self._get(self.config_init, "data")
self.model = self._get(self.config_init, "model")
self._add_configure_optimizers_method_to_model(self.subcommand)
@@ -754,3 +786,17 @@ def _get_short_description(component: object) -> Optional[str]:
return docstring.short_description
except (ValueError, docstring_parser.ParseError) as ex:
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")


ModuleType = TypeVar("ModuleType")


def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
parser = ArgumentParser(exit_on_error=False)
if "class_path" in config:
parser.add_subclass_arguments(class_type, "module")
else:
parser.add_class_arguments(class_type, "module")
cfg = parser.parse_object({"module": config})
init = parser.instantiate_classes(cfg)
return init.module
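The new `instantiate_module` helper rebuilds a module through jsonargparse: hyperparameters stored as `class_path`/`init_args` dictionaries are resolved back into live objects, while plain keyword dictionaries fall back to `add_class_arguments`. A rough sketch of the subclass branch, assuming jsonargparse resolves the config as shown:

```python
import torch
from lightning.pytorch.cli import instantiate_module

# A config in the class_path/init_args form (as the CLI stores it in hparams)
# is parsed and instantiated through jsonargparse.
config = {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05}}
activation = instantiate_module(torch.nn.Module, config)
assert isinstance(activation, torch.nn.LeakyReLU)
```

In the checkpoint-loading path the helper is not called directly; it is passed as `instantiator` to `load_from_checkpoint`, as the new test below does.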
19 changes: 17 additions & 2 deletions src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -15,6 +15,8 @@
import inspect
import types
from argparse import Namespace
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, List, MutableMapping, Optional, Sequence, Union

from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters
@@ -23,6 +25,18 @@
_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)


given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None)


@contextmanager
def given_hyperparameters_context(value):
token = given_hyperparameters.set(value)
try:
yield
finally:
given_hyperparameters.reset(token)


class HyperparametersMixin:
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

@@ -103,12 +117,13 @@ class ``__init__`` to be ignored
"arg3": 3.14
"""
self._log_hyperparams = logger
given_hparams = given_hyperparameters.get()
# the frame needs to be created in this file.
if not frame:
if given_hparams is None and not frame:
current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)

def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
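The new context variable is the channel through which the CLI's `class_instantiator` hands the already-parsed configuration to `save_hyperparameters`, bypassing frame inspection. A toy sketch of the mechanism (the model and values are hypothetical, not from the commit):

```python
from lightning.pytorch import LightningModule
from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context


class ToyModel(LightningModule):  # hypothetical module for illustration
    def __init__(self, hidden: int = 16):
        super().__init__()
        # With the context set, the provided dict is stored directly instead of
        # collecting the constructor arguments from the caller's frame.
        self.save_hyperparameters()


with given_hyperparameters_context({"hidden": 32}):
    model = ToyModel(hidden=32)

assert model.hparams["hidden"] == 32
```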
3 changes: 2 additions & 1 deletion src/lightning/pytorch/core/saving.py
@@ -123,6 +123,7 @@ def _load_state(
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint: Dict[str, Any],
strict: Optional[bool] = None,
instantiator=None,
**cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
cls_spec = inspect.getfullargspec(cls.__init__)
@@ -160,7 +161,7 @@ def _load_state(
# filter kwargs according to class init unless it allows any argument via kwargs
_cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

obj = cls(**_cls_kwargs)
obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

if isinstance(obj, pl.LightningModule):
# give model a chance to load something
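With this change `_load_state` accepts any callable of the form `instantiator(cls, kwargs)` and keeps the previous behaviour when none is given. A trivial stand-in that reproduces the default path, just to show the expected signature:

```python
# Equivalent to the default ``cls(**kwargs)`` branch; a real instantiator such as
# ``instantiate_module`` resolves class_path/init_args entries before construction.
def plain_instantiator(cls, kwargs):
    return cls(**kwargs)
```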
10 changes: 8 additions & 2 deletions src/lightning/pytorch/utilities/parsing.py
@@ -138,7 +138,11 @@ def collect_init_args(


def save_hyperparameters(
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
obj: Any,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None,
given_hparams: Optional[Dict[str, Any]] = None,
) -> None:
"""See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""

@@ -154,7 +158,9 @@ def save_hyperparameters(
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
if given_hparams is not None:
init_args = given_hparams
elif is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
48 changes: 48 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -40,6 +40,7 @@
from lightning.pytorch.cli import (
_JSONARGPARSE_SIGNATURES_AVAILABLE,
instantiate_class,
instantiate_module,
LightningArgumentParser,
LightningCLI,
LRSchedulerCallable,
@@ -833,6 +834,53 @@ def configure_optimizers(self):
assert init[1]["lr_scheduler"].gamma == 0.3


def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
class TestModel(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer
self.scheduler = scheduler
self.activation = activation

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
scheduler = self.scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}

with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
cli.trainer.fit(cli.model)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
expected = {
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
assert hparams == expected

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path)
assert ckpt["hyper_parameters"] == expected

model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module)
assert isinstance(model, TestModel)
assert isinstance(model.activation, torch.nn.LeakyReLU)
assert model.activation.negative_slope == 0.05
optimizer, lr_scheduler = model.configure_optimizers().values()
assert isinstance(optimizer, torch.optim.Adam)
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
