From 82297aa1b3dd8f3d509505c9def266c179506211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 15 Nov 2021 23:12:55 +0100 Subject: [PATCH] Fix `to_torchscript()` causing false positive deprecation warnings (#10470) --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 17 ++++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03b8650b6d401..102f74894d1dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486)) +- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470)) + + - Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c59193859b171..bc89cc2b18e93 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -116,6 +116,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._param_requires_grad_state = {} self._metric_attributes: Optional[Dict[int, str]] = None self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False + # TODO: remove after the 1.6 release + self._running_torchscript = False self._register_sharded_tensor_state_dict_hooks_if_available() @@ -1962,6 +1964,8 @@ def to_torchscript( """ mode = self.training + self._running_torchscript = True + if method == "script": torchscript_module = torch.jit.script(self.eval(), **kwargs) elif method == "trace": @@ -1987,6 +1991,8 @@ def to_torchscript( with fs.open(file_path, "wb") as f: torch.jit.save(torchscript_module, f) + self._running_torchscript = False + return torchscript_module @property @@ -1996,11 +2002,12 @@ def model_size(self) -> float: Note: This property will not return correct value for Deepspeed (stage 3) and fully-sharded training. """ - rank_zero_deprecation( - "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7." - " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.", - stacklevel=5, - ) + if not self._running_torchscript: # remove with the deprecation removal + rank_zero_deprecation( + "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7." + " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.", + stacklevel=5, + ) return get_model_size_mb(self) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: