load_from_checkpoint support for LightningCLI when using dependency injection #18105

Merged: 10 commits, Feb 23, 2024
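
For context, here is a minimal end-to-end sketch of what this PR enables, condensed from the test added at the bottom of this diff (the class name, argv values, and the direct sys.argv patching are illustrative only):

import sys
import torch
from pathlib import Path
from lightning.pytorch.cli import LightningCLI, OptimizerCallable, instantiate_module
from lightning.pytorch.demos.boring_classes import BoringModel

class MyModel(BoringModel):
    def __init__(self, optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()
        # With this PR, the CLI injects the resolved config here, so the callable is
        # recorded as "torch.optim.Adam" instead of an unpicklable object.
        self.save_hyperparameters()
        self.optimizer = optimizer

    def configure_optimizers(self):
        return self.optimizer(self.parameters())

sys.argv = ["any.py", "--trainer.max_epochs=1"]
cli = LightningCLI(MyModel, run=False, auto_configure_optimizers=False)
cli.trainer.fit(cli.model)

# Loading now works even though __init__ takes callables/classes rather than plain values:
ckpt_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
model = MyModel.load_from_checkpoint(ckpt_path, instantiator=instantiate_module)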
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075))
- Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404))

- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))


### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
51 changes: 50 additions & 1 deletion src/lightning/pytorch/cli.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union

import torch
import yaml
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer
@@ -27,6 +29,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
@@ -50,6 +53,8 @@
locals()["ArgumentParser"] = object
locals()["Namespace"] = object

ModuleType = TypeVar("ModuleType")


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -381,6 +386,7 @@ def __init__(

self._set_seed()

self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()

@@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
subclasses=self.subclass_mode_model,
)
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="data"),
_get_module_type(self._datamodule_class),
subclasses=self.subclass_mode_data,
)

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""

@@ -755,3 +777,30 @@ def _get_short_description(component: object) -> Optional[str]:
return docstring.short_description
except (ValueError, docstring_parser.ParseError) as ex:
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")


def _get_module_type(value: Union[Callable, type]) -> type:
if callable(value) and not isinstance(value, type):
return inspect.signature(value).return_annotation
return value


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})):
return class_type(*args, **kwargs)


def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
parser = ArgumentParser(exit_on_error=False)
if "class_path" in config:
parser.add_subclass_arguments(class_type, "module")
else:
parser.add_class_arguments(class_type, "module")
cfg = parser.parse_object({"module": config})
init = parser.instantiate_classes(cfg)
return init.module
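
Note on the new helpers: _InstantiatorFn stores the parsed CLI config as the module's hyperparameters at construction time, while instantiate_module is its counterpart at load time, re-parsing a saved hyperparameters dict with jsonargparse so that strings such as "torch.optim.Adam" or class_path/init_args entries become live objects again. A rough sketch of calling it directly, assuming the MyModel and ckpt_path from the sketch near the top of this page:

import torch
from lightning.pytorch.cli import instantiate_module

ckpt = torch.load(ckpt_path)                  # checkpoint produced during fit
hparams = ckpt["hyper_parameters"]            # e.g. {"optimizer": "torch.optim.Adam"}
model = instantiate_module(MyModel, hparams)  # resolves the string back to torch.optim.Adam
model.load_state_dict(ckpt["state_dict"])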
21 changes: 18 additions & 3 deletions src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -15,7 +15,9 @@
import inspect
import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union

from lightning.fabric.utilities.data import AttributeDict
from lightning.pytorch.utilities.parsing import save_hyperparameters
@@ -24,6 +26,18 @@
_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)


given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None)


@contextmanager
def given_hyperparameters_context(value: dict) -> Iterator[None]:
token = given_hyperparameters.set(value)
try:
yield
finally:
given_hyperparameters.reset(token)


class HyperparametersMixin:
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

@@ -105,12 +119,13 @@ class ``__init__`` to be ignored

"""
self._log_hyperparams = logger
given_hparams = given_hyperparameters.get()
# the frame needs to be created in this file.
if not frame:
if given_hparams is None and not frame:
current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)

def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
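The ContextVar added above is what lets the CLI hand an already-parsed config to save_hyperparameters() instead of relying on frame inspection. A self-contained toy of the same pattern, not Lightning code, just the mechanism being introduced:

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional

_given: ContextVar[Optional[dict]] = ContextVar("_given", default=None)

@contextmanager
def given_context(value: dict) -> Iterator[None]:
    token = _given.set(value)
    try:
        yield
    finally:
        _given.reset(token)  # always restore, even if construction raises

class Saver:
    def save_hyperparameters(self) -> None:
        given = _given.get()
        # Prefer the externally supplied config; otherwise fall back to introspection
        # (frame inspection in the real implementation).
        self.hparams = given if given is not None else {}

with given_context({"optimizer": "torch.optim.Adam"}):
    obj = Saver()
    obj.save_hyperparameters()
assert obj.hparams == {"optimizer": "torch.optim.Adam"}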
3 changes: 2 additions & 1 deletion src/lightning/pytorch/core/saving.py
@@ -118,6 +118,7 @@ def _load_state(
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint: Dict[str, Any],
strict: Optional[bool] = None,
instantiator: Optional[Callable] = None,
**cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
cls_spec = inspect.getfullargspec(cls.__init__)
@@ -155,7 +156,7 @@ def _load_state(
# filter kwargs according to class init unless it allows any argument via kwargs
_cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

obj = cls(**_cls_kwargs)
obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

if isinstance(obj, pl.LightningDataModule):
if obj.__class__.__qualname__ in checkpoint:
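With this one-line change, anything callable as instantiator(cls, kwargs_dict) can take over construction during checkpoint loading; instantiate_module from cli.py is the intended implementation, but the contract is simple enough to wrap. A hypothetical helper, shown only to illustrate the expected signature:

def logging_instantiator(cls, kwargs):
    # Same contract as instantiate_module: take the class and its init kwargs, return an instance.
    print(f"Restoring {cls.__name__} with {sorted(kwargs)}")
    return cls(**kwargs)

# model = MyModel.load_from_checkpoint(ckpt_path, instantiator=logging_instantiator)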
10 changes: 8 additions & 2 deletions src/lightning/pytorch/utilities/parsing.py
@@ -140,7 +140,11 @@ def collect_init_args(


def save_hyperparameters(
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
obj: Any,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None,
given_hparams: Optional[Dict[str, Any]] = None,
) -> None:
"""See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""

@@ -156,7 +160,9 @@ def save_hyperparameters(
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
if given_hparams is not None:
init_args = given_hparams
elif is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
48 changes: 48 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -39,6 +39,7 @@
OptimizerCallable,
SaveConfigCallback,
instantiate_class,
instantiate_module,
)
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
@@ -835,6 +836,53 @@ def configure_optimizers(self):
assert init[1]["lr_scheduler"].gamma == 0.3


def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
class TestModel(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer
self.scheduler = scheduler
self.activation = activation

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
scheduler = self.scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}

with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
cli.trainer.fit(cli.model)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
expected = {
"optimizer": "torch.optim.Adam",
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
assert hparams == expected

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
ckpt = torch.load(checkpoint_path)
assert ckpt["hyper_parameters"] == expected

model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module)
assert isinstance(model, TestModel)
assert isinstance(model.activation, torch.nn.LeakyReLU)
assert model.activation.negative_slope == 0.05
optimizer, lr_scheduler = model.configure_optimizers().values()
assert isinstance(optimizer, torch.optim.Adam)
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
Expand Down