fix: MLFlowLogger artifact_path #20669

Merged: 2 commits into Lightning-AI:master, Mar 25, 2025

Conversation

yxtay
Contributor

@yxtay yxtay commented Mar 24, 2025

What does this PR do?

#20538

  • The changes introduced in #20538 broke MLflow logging of checkpoints.
  • MlflowClient.log_artifact accepts only a string for the artifact_path argument.
  • I get the following error as a result. The message is not exactly helpful, but the type mismatch turned out to be the cause:
    mlflow.exceptions.MlflowException: Invalid artifact path: 'epoch=0-step=1'. Names may be treated as files in certain cases, and must not resolve to other names when treated as such. This name would resolve to 'epoch=0-step=1'
  • This change fixes it.

Fixes #20664

# example.py

import os

import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)


class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# train model
mlflow_logger = L.pytorch.loggers.mlflow.MLFlowLogger(log_model=True)
trainer = L.Trainer(logger=mlflow_logger, max_steps=50)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20669.org.readthedocs.build/en/20669/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Mar 24, 2025

codecov bot commented Mar 24, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79%. Comparing base (669486a) to head (429f732).
Report is 2 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (669486a) and HEAD (429f732).

HEAD has 844 uploads less than BASE
Flag BASE (669486a) HEAD (429f732)
gpu 4 0
pytest 112 0
lightning_fabric 51 0
lightning 121 15
python3.10 48 6
cpu 215 27
python 23 3
python3.11 48 6
python3.12 24 3
python3.12.7 72 9
pytorch2.1 23 6
pytorch_lightning 47 12
pytest-full 107 27
pytorch2.2.2 12 3
pytorch2.3 12 3
pytorch2.5 12 3
pytorch2.4.1 12 3
pytorch2.7 12 3
pytorch2.5.1 12 3
pytorch2.6 12 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20669     +/-   ##
=========================================
- Coverage      87%      79%     -9%     
=========================================
  Files         268      265      -3     
  Lines       23449    23394     -55     
=========================================
- Hits        20499    18386   -2113     
- Misses       2950     5008   +2058     

@niander

niander commented Mar 24, 2025

Related bug: #20664

@niander niander left a comment

This fix should also make sure that the change works on Windows when a checkpoint_path_prefix is also passed when creating the MLFlowLogger object. Calling str(...) on a pathlib.Path created on Windows may produce backslashes, which does not seem to be what MlflowClient expects.
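
To illustrate the Windows concern: str() on a Windows path uses backslashes, while as_posix() yields the forward-slash form that MLflow's artifact-path handling expects. A small demonstration with PureWindowsPath, which behaves the same on any OS:

```python
from pathlib import PureWindowsPath

p = PureWindowsPath("checkpoints") / "epoch=0-step=1"
print(str(p))        # checkpoints\epoch=0-step=1  (backslash separator)
print(p.as_posix())  # checkpoints/epoch=0-step=1  (forward slashes)
```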

@Borda Borda merged commit df5dee6 into Lightning-AI:master Mar 25, 2025
83 checks passed
@Northo

Northo commented Apr 10, 2025

Great to see this fixed! Do you know when we will have a release with this patch?

@niander

niander commented Apr 23, 2025

@Borda can we get this into a new release soon? What is the release cadence? This fix unblocks people using MLFlowLogger for checkpoints.

Borda added a commit that referenced this pull request Apr 24, 2025
Borda added a commit that referenced this pull request Apr 24, 2025
@MohammadElSakka

MohammadElSakka commented Apr 29, 2025

I'm still getting the same problem after upgrading packages. Any help would be appreciated :-)

torch                 2.6.0+cu118
torchaudio            2.6.0+cu118
torchmetrics          1.7.1
torchvision           0.21.0+cu118
lightning             2.5.1.post0
lightning-utilities   0.14.3
pytorch-lightning     2.5.1.post0
mlflow                2.22.0
mlflow-skinny         2.22.0

@yxtay
Contributor Author

yxtay commented Apr 29, 2025

The PR changes were reverted on the master branch and were not included in 2.5.1.post0. I'm not sure what the reason was; just stating what I found.

Commit: f6ef409

@MohammadElSakka

My workaround was to manually change the files inside my venv according to the changes in https://github.com/Lightning-AI/pytorch-lightning/pull/20669/files

Seems to work fine
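
A slightly less invasive variant of the same workaround is to patch at runtime instead of editing files inside the venv. The sketch below is hypothetical (the decorator name is invented, and applying it to MlflowClient.log_artifact before training starts is an assumption): it wraps a log_artifact-style callable so any pathlib.Path passed as artifact_path is converted to a posix string first.

```python
from functools import wraps
from pathlib import PurePath

def posixify_artifact_path(log_artifact):
    """Wrap a log_artifact-style callable so pathlib.Path values for
    artifact_path are converted to posix strings before the call."""
    @wraps(log_artifact)
    def wrapper(self, run_id, local_path, artifact_path=None):
        if isinstance(artifact_path, PurePath):
            artifact_path = artifact_path.as_posix()
        return log_artifact(self, run_id, local_path, artifact_path)
    return wrapper

# Hypothetical application (requires mlflow installed):
# import mlflow.tracking
# mlflow.tracking.MlflowClient.log_artifact = posixify_artifact_path(
#     mlflow.tracking.MlflowClient.log_artifact
# )
```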

@niander

niander commented May 3, 2025

I am also very confused about why they reverted it. It seems they reverted more than just this PR, including the other commit that added a custom path prefix for the artifacts.
@Borda

@Clara-liu

Our CI is failing, which forced us to pin lightning to an older version. Is there a reason this fix was reverted? Thanks.

@MohammadElSakka

Is there an older version that is stable that has this fixed? Which workaround are you guys doing? It is a bit frustrating to allocate a whole section in my README.md to explain how to fix MLFlow in my program ... 🫠

@Clara-liu

> Is there an older version that is stable that has this fixed? Which workaround are you guys doing? It is a bit frustrating to allocate a whole section in my README.md to explain how to fix MLFlow in my program ... 🫠

We've pinned lightning to 2.5.0 until this is resolved.

@MohammadElSakka

> Is there an older version that is stable that has this fixed? Which workaround are you guys doing? It is a bit frustrating to allocate a whole section in my README.md to explain how to fix MLFlow in my program ... 🫠
>
> We've pinned lightning to 2.5.0 until this is resolved.

I'll experiment with that 🫡

Labels
pl Generic label for PyTorch Lightning package
Development

Successfully merging this pull request may close these issues.

MLFlowLogger fails to log artifact on Windows
6 participants