PR: Fix Duplicate Metric Logging in MLFlowLogger to Prevent MLflow Database Errors #20871

Open · wants to merge 6 commits into base: master
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
@@ -17,3 +17,4 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App

tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger`
mlflow
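
The new `mlflow` requirement is unpinned; it is needed because the test added below exercises the real MLFlowLogger, which builds `mlflow.entities.Metric` objects in `log_metrics`. A minimal, illustrative sanity check (not part of this PR) that the dependency resolves in the test environment:

```python
# Illustrative only: confirm the new test dependency imports and that the
# Metric entity used by MLFlowLogger.log_metrics is available.
import mlflow
from mlflow.entities import Metric

print("mlflow version:", mlflow.__version__)
m = Metric(key="loss", value=0.5, timestamp=0, step=0)
print(m.key, m.value, m.step)
```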
14 changes: 14 additions & 0 deletions src/lightning/pytorch/loggers/mlflow.py
@@ -109,6 +109,12 @@ def any_lightning_module_function_or_hook(self):
ModuleNotFoundError:
If required MLFlow package is not installed on the device.

Note:
As of vX.XX, MLFlowLogger will skip logging any metric (same name and step)
more than once per run, to prevent database unique constraint violations on
some MLflow backends (such as PostgreSQL). Only the first value for each (metric, step)
pair will be logged per run. This improves robustness for all users.

"""

LOGGER_JOIN_CHAR = "-"
@@ -151,6 +157,7 @@ def __init__(
from mlflow.tracking import MlflowClient

self._mlflow_client = MlflowClient(tracking_uri)
self._logged_metrics = set() # Track (key, step)

@property
@rank_zero_experiment
@@ -201,6 +208,7 @@ def experiment(self) -> "MlflowClient":
resolve_tags = _get_resolve_tags()
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
self._logged_metrics.clear()
self._initialized = True
return self._mlflow_client

@@ -266,6 +274,12 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
category=RuntimeWarning,
)
k = new_k

metric_id = (k, step or 0)
if metric_id in self._logged_metrics:
continue
self._logged_metrics.add(metric_id)

metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
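
For reviewers skimming the diff: the change boils down to a per-run set of `(key, step)` tuples that filters the batch before it reaches `log_batch`. The sketch below is a self-contained illustration of that idea; `DedupMetricBuffer` and its names are illustrative and not part of the Lightning codebase.

```python
# Minimal sketch of the (key, step) deduplication used above; names are illustrative.
from typing import Optional


class DedupMetricBuffer:
    """Collects metrics, dropping any (key, step) pair seen earlier in the run."""

    def __init__(self) -> None:
        self._seen: set[tuple[str, int]] = set()

    def add(self, metrics: dict[str, float], step: Optional[int] = None) -> list[tuple[str, float, int]]:
        batch = []
        for key, value in metrics.items():
            metric_id = (key, step or 0)  # None and 0 both map to step 0, as in the diff
            if metric_id in self._seen:
                continue  # duplicate (metric, step) pair: skip, mirroring log_metrics above
            self._seen.add(metric_id)
            batch.append((key, value, step or 0))
        return batch


buffer = DedupMetricBuffer()
assert buffer.add({"loss": 0.5}, step=1) == [("loss", 0.5, 1)]
assert buffer.add({"loss": 0.5}, step=1) == []  # second call at the same step is dropped
```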
47 changes: 47 additions & 0 deletions tests/tests_pytorch/loggers/test_mlflow.py
@@ -427,3 +427,50 @@ def test_set_tracking_uri(mlflow_mock):
mlflow_mock.set_tracking_uri.assert_not_called()
_ = logger.experiment
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")


def test_mlflowlogger_metric_deduplication(monkeypatch):
import types

from lightning.pytorch.loggers.mlflow import MLFlowLogger

# Dummy MLflow client to record log_batch calls
logged_metrics = []

class DummyMlflowClient:
def log_batch(self, run_id, metrics, **kwargs):
logged_metrics.extend(metrics)

def set_tracking_uri(self, uri):
pass

def create_run(self, experiment_id, tags):
class Run:
info = types.SimpleNamespace(run_id="dummy_run_id")

return Run()

def get_run(self, run_id):
class Run:
info = types.SimpleNamespace(experiment_id="dummy_experiment_id")

return Run()

def get_experiment_by_name(self, name):
return None

def create_experiment(self, name, artifact_location=None):
return "dummy_experiment_id"

# Patch the MLFlowLogger to use DummyMlflowClient
monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient())

logger = MLFlowLogger(experiment_name="test_exp")
logger.log_metrics({"foo": 1.0}, step=5)
logger.log_metrics({"foo": 1.0}, step=5) # duplicate

# Only the first metric should be logged
assert len(logged_metrics) == 1
assert logged_metrics[0].key == "foo"
assert logged_metrics[0].value == 1.0
assert logged_metrics[0].step == 5
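
To run only the new test locally (assuming pytest and mlflow are installed, per the requirements change above), one option is to invoke pytest programmatically; the node ID below matches the file and test name added in this PR.

```python
# Run just the new deduplication test; assumes pytest and mlflow are installed.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "tests/tests_pytorch/loggers/test_mlflow.py::test_mlflowlogger_metric_deduplication",
        "-v",
    ]))
```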