Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete TensorBoardLogger experiment before spawning the processes. #10777

Merged
merged 16 commits into from
Nov 26, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))


- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777))


## [1.5.2] - 2021-11-16

### Fixed
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
Expand Down Expand Up @@ -172,6 +173,7 @@ def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kw
Return:
The output of the function of process 0.
"""
self._clean_logger(*args)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
context = mp.get_context("spawn")
return_queue = context.SimpleQueue() if return_result else None
Expand Down Expand Up @@ -415,3 +417,13 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

@staticmethod
def _clean_logger(*args: Any) -> None:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(args[0], pl.Trainer):
trainer = args[0]
loggers = trainer._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
rank_zero_warn("When using `ddp_spawn`, the Tensorboard experiment should be `None`.")
logger._experiment = None