
Commit

Fix to_torchscript() causing false positive deprecation warnings (#…
carmocca authored and lexierule committed Nov 16, 2021
1 parent 53ff840 commit 5f4a5fe
Showing 2 changed files with 15 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 
 
+- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+
+
 - Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
 
 
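
To illustrate the new changelog entry: a minimal sketch of the reported scenario (the module definition and layer sizes are placeholders, not taken from this commit). Exporting a LightningModule via `to_torchscript()` used to emit the `model_size` deprecation warning even though user code never read `model_size`, presumably because scripting inspects the module's attributes and evaluates the deprecated property getter along the way.

import torch
from pytorch_lightning import LightningModule


class BoringNet(LightningModule):  # placeholder module for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


model = BoringNet()
# Before this commit, the call below could emit the `model_size` deprecation
# warning even though `model_size` is never accessed directly; with the
# `_running_torchscript` guard added in the diff below, it no longer does.
scripted = model.to_torchscript(method="script")
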
17 changes: 12 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -116,6 +116,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._param_requires_grad_state = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
+        # TODO: remove after the 1.6 release
+        self._running_torchscript = False
 
         self._register_sharded_tensor_state_dict_hooks_if_available()
 
@@ -1962,6 +1964,8 @@ def to_torchscript(
         """
         mode = self.training
 
+        self._running_torchscript = True
+
         if method == "script":
             torchscript_module = torch.jit.script(self.eval(), **kwargs)
         elif method == "trace":
@@ -1987,6 +1991,8 @@
             with fs.open(file_path, "wb") as f:
                 torch.jit.save(torchscript_module, f)
 
+        self._running_torchscript = False
+
         return torchscript_module
 
     @property
@@ -1996,11 +2002,12 @@ def model_size(self) -> float:
         Note:
             This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.
         """
-        rank_zero_deprecation(
-            "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
-            " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
-            stacklevel=5,
-        )
+        if not self._running_torchscript:  # remove with the deprecation removal
+            rank_zero_deprecation(
+                "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
+                " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
+                stacklevel=5,
+            )
         return get_model_size_mb(self)
 
     def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
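
For reference, here is a short sketch of the migration path named in the deprecation message above, assuming an already constructed LightningModule instance `model`:

from pytorch_lightning.utilities.memory import get_model_size_mb

# Replacement for the deprecated `LightningModule.model_size` property.
size_mb = get_model_size_mb(model)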
