Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed NeptuneLogger when using DDP #11030

Merged
merged 13 commits into from
Dec 18, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116))
- Fixed `NeptuneLogger` when using DDP ([#11030](https://github.com/PyTorchLightning/pytorch-lightning/pull/11030))


- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116))


- Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ dependencies:
- mlflow>=1.0.0
- comet_ml>=3.1.12
- wandb>=0.8.21
- neptune-client>=0.4.109
- neptune-client>=0.10.0
124 changes: 75 additions & 49 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from neptune.new.types import File as NeptuneFile
except ModuleNotFoundError:
import neptune
from neptune.exceptions import NeptuneLegacyProjectException
from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
from neptune.run import Run
from neptune.types import File as NeptuneFile
else:
Expand Down Expand Up @@ -266,51 +266,64 @@ def __init__(
prefix: str = "training",
**neptune_run_kwargs,
):

# verify if user passed proper init arguments
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
if neptune is None:
raise ModuleNotFoundError(
"You want to use the `Neptune` logger which is not installed yet, install it with"
" `pip install neptune-client`."
)

super().__init__()
self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix
self._run_name = name
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
self._project_name = project
self._api_key = api_key
self._run_instance = run
self._neptune_run_kwargs = neptune_run_kwargs
self._run_short_id = None

self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)
if self._run_instance is not None:
self._retrieve_run_data()

self._run_short_id = self.run._short_id # skipcq: PYL-W0212
# make sure that we've log integration version for outside `Run` instances
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

def _retrieve_run_data(self):
try:
self.run.wait()
self._run_instance.wait()
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_name = "offline-name"

def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
if run is not None:
run_instance = run
else:
try:
run_instance = neptune.init(
project=project,
api_token=api_key,
name=name,
**neptune_run_kwargs,
)
except NeptuneLegacyProjectException as e:
raise TypeError(
f"""Project {project} has not been migrated to the new structure.
You can still integrate it with the Neptune logger using legacy Python API
available as part of neptune-contrib package:
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
"""
) from e

# make sure that we've log integration version for both newly created and outside `Run` instances
run_instance[_INTEGRATION_VERSION_KEY] = __version__

# keep api_key and project, they will be required when resuming Run for pickled logger
self._api_key = api_key
self._project_name = run_instance._project_name # skipcq: PYL-W0212
@property
def _neptune_init_args(self):
args = {}
# Backward compatibility in case of previous version retrieval
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
try:
args = self._neptune_run_kwargs
except AttributeError:
pass

if self._project_name is not None:
args["project"] = self._project_name

if self._api_key is not None:
args["api_token"] = self._api_key

return run_instance
if self._run_short_id is not None:
args["run"] = self._run_short_id

# Backward compatibility in case of previous version retrieval
try:
if self._run_name is not None:
args["name"] = self._run_name
except AttributeError:
pass

return args

def _construct_path_with_prefix(self, *keys) -> str:
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
Expand Down Expand Up @@ -379,7 +392,7 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
self._run_instance = neptune.init(project=self._project_name, api_token=self._api_key, run=self._run_short_id)
self._run_instance = neptune.init(**self._neptune_init_args)

@property
@rank_zero_experiment
Expand Down Expand Up @@ -412,8 +425,23 @@ def training_step(self, batch, batch_idx):
return self.run

@property
@rank_zero_experiment
def run(self) -> Run:
return self._run_instance
try:
if not self._run_instance:
self._run_instance = neptune.init(**self._neptune_init_args)
self._retrieve_run_data()
# make sure that we've log integration version for newly created
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

return self._run_instance
except NeptuneLegacyProjectException as e:
raise TypeError(
f"Project {self._project_name} has not been migrated to the new structure."
" You can still integrate it with the Neptune logger using legacy Python API"
" available as part of neptune-contrib package:"
" https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
) from e

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
Expand Down Expand Up @@ -474,12 +502,12 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which
# Lightning does not always guarantee.
self.experiment[key].log(val)
self.run[key].log(val)

@rank_zero_only
def finalize(self, status: str) -> None:
if status:
self.experiment[self._construct_path_with_prefix("status")] = status
self.run[self._construct_path_with_prefix("status")] = status

super().finalize(status)

Expand All @@ -493,12 +521,14 @@ def save_dir(self) -> Optional[str]:
"""
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
model_str = str(ModelSummary(model=model, max_depth=max_depth))
self.experiment[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
content=model_str, extension="txt"
)

@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.

Expand All @@ -515,35 +545,33 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
if checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)

# log best model path and checkpoint
if checkpoint_callback.best_model_path:
self.experiment[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.experiment.exists(checkpoints_namespace):
exp_structure = self.experiment.get_structure()
if self.run.exists(checkpoints_namespace):
exp_structure = self.run.get_structure()
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)

for file_to_drop in list(uploaded_model_names - file_names):
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]

# log best model score
if checkpoint_callback.best_model_score:
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)

Expand Down Expand Up @@ -637,13 +665,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
self.run[key].log(destination)

@rank_zero_only
def set_property(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
)

@rank_zero_only
def append_tags(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
Expand Down
6 changes: 5 additions & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def _get_logger_args(logger_class, save_dir):
logger_args.update(offline_mode=True)
if "offline" in inspect.getfullargspec(logger_class).args:
logger_args.update(offline=True)
if issubclass(logger_class, NeptuneLogger):
logger_args.update(mode="offline")
return logger_args


Expand Down Expand Up @@ -330,7 +332,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):


@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
@pytest.mark.parametrize(
"logger_class", [CometLogger, CSVLogger, MLFlowLogger, NeptuneLogger, TensorBoardLogger, TestTubeLogger]
)
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
_patch_comet_atexit(monkeypatch)
Expand Down
6 changes: 3 additions & 3 deletions tests/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
class TestNeptuneLogger(unittest.TestCase):
def test_neptune_online(self, neptune):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger._run_instance
created_run_mock = logger.run

self.assertEqual(logger._run_instance, created_run_mock)
self.assertEqual(logger.name, "Run test name")
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_neptune_pickling(self, neptune):
pickled_logger = pickle.dumps(logger)
unpickled = pickle.loads(pickled_logger)

neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
neptune.init.assert_called_once_with(name="Test name", run=unpickleable_run._short_id)
self.assertIsNotNone(unpickled.experiment)

@patch("pytorch_lightning.loggers.neptune.Run", Run)
Expand Down Expand Up @@ -360,7 +360,7 @@ def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
logger = NeptuneLogger(api_key="test", project="project")

# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
attr_mock = logger._run_instance.__getitem__
attr_mock = logger.run.__getitem__
attr_mock.reset_mock()
fake_image = {}

Expand Down