From 7ce2f8f1c08db09f67ce1c1fc9589512c8ab69d3 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Wed, 10 Jul 2024 14:50:54 -0400
Subject: [PATCH] no-op _add_instantiators (#207)

---
 cellarium/ml/cli.py         |  5 +++++
 cellarium/ml/core/module.py | 14 ++++++++++++++
 pyproject.toml              |  2 +-
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/cellarium/ml/cli.py b/cellarium/ml/cli.py
index e4ff7525..ea54401a 100644
--- a/cellarium/ml/cli.py
+++ b/cellarium/ml/cli.py
@@ -274,6 +274,11 @@ def __init__(self, args: ArgsType = None) -> None:
             args=args,
         )
 
+    def _add_instantiators(self) -> None:
+        # disable the breaking change to dependency-injection support introduced in PyTorch Lightning 2.3
+        # https://github.com/Lightning-AI/pytorch-lightning/pull/18105
+        pass
+
     def instantiate_classes(self) -> None:
         with torch.device("meta"):
             # skip the initialization of model parameters
diff --git a/cellarium/ml/core/module.py b/cellarium/ml/core/module.py
index e7926414..5826213a 100644
--- a/cellarium/ml/core/module.py
+++ b/cellarium/ml/core/module.py
@@ -61,6 +61,12 @@ def __init__(
         self.save_hyperparameters(logger=False)
         self.pipeline: CellariumPipeline | None = None
 
+        if optim_fn is None:
+            # Starting with PyTorch Lightning 2.3, automatic optimization doesn't allow returning None
+            # from the training_step during distributed training. https://github.com/Lightning-AI/pytorch-lightning/pull/19918
+            # Thus, we need to use manual optimization for the no-optimizer case.
+            self.automatic_optimization = False
+
     def configure_model(self) -> None:
         """
         .. note::
@@ -156,6 +162,14 @@ def training_step(  # type: ignore[override]
         if loss is not None:
             # Logging to TensorBoard by default
             self.log("train_loss", loss, sync_dist=True)
+
+        if not self.automatic_optimization:
+            # Note that running .step() is necessary for incrementing the global step even though
+            # no backpropagation is performed.
+            no_optimizer = self.optimizers()
+            assert isinstance(no_optimizer, pl.core.optimizer.LightningOptimizer)
+            no_optimizer.step()
+
         return loss
 
     def forward(self, batch: dict[str, np.ndarray | torch.Tensor]) -> dict[str, np.ndarray | torch.Tensor]:
diff --git a/pyproject.toml b/pyproject.toml
index 7d46c9f9..22d7ecca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "crick>=0.0.4",
     "google-cloud-storage",
     "jsonargparse[signatures]==4.27.7",
-    "lightning>=2.2.0, <2.3",
+    "lightning>=2.2.0",
     "pyro-ppl>=1.9.1",
     "pytest",
     "torch>=2.2.0",
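
Note on the cli.py hunk: a minimal standalone sketch of the override pattern,
shown for context. Making LightningCLI._add_instantiators a no-op skips the
custom class instantiators that Lightning 2.3 registers
(https://github.com/Lightning-AI/pytorch-lightning/pull/18105), so classes keep
being instantiated by jsonargparse itself, preserving the pre-2.3
dependency-injection behavior. MinimalCLI is a hypothetical name; the actual
subclass in cellarium/ml/cli.py also overrides instantiate_classes, as the
surrounding diff context shows.

    from lightning.pytorch.cli import LightningCLI

    class MinimalCLI(LightningCLI):
        def _add_instantiators(self) -> None:
            # Intentionally a no-op: do not register Lightning >= 2.3's custom
            # instantiators; fall back to plain jsonargparse instantiation.
            pass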
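Note on the module.py hunks: a self-contained sketch of the manual-optimization
workaround, under the assumption that a zero-lr SGD optimizer stands in for the
"no optimizer" case (NoOptimizerModule and the dummy optimizer are hypothetical;
CellariumModule wires this through its optim_fn argument instead). Under
automatic optimization, Lightning >= 2.3 no longer accepts a training_step that
returns None during distributed training
(https://github.com/Lightning-AI/pytorch-lightning/pull/19918), so the module
switches to manual optimization and calls .step() solely to advance the global
step.

    import lightning.pytorch as pl
    import torch
    from lightning.pytorch.core.optimizer import LightningOptimizer

    class NoOptimizerModule(pl.LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            # Opt out of automatic optimization so that returning None from
            # training_step is allowed under distributed training.
            self.automatic_optimization = False

        def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
            # No loss and no .backward(): the model is not updated by gradients.
            optimizer = self.optimizers()
            assert isinstance(optimizer, LightningOptimizer)
            # .step() updates no parameters here (lr=0.0) but increments
            # trainer.global_step, keeping logging and checkpointing in sync.
            optimizer.step()

        def configure_optimizers(self) -> torch.optim.Optimizer:
            # Hypothetical stand-in for "no optimizer": a zero learning rate
            # makes .step() a parameter no-op.
            return torch.optim.SGD(self.parameters(), lr=0.0)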