Skip to content

Commit 288653a

Browse files
C080winglian
andauthored
Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifa… (axolotl-ai-cloud#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting * cleanup and lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
1 parent 3a5b495 commit 288653a

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

src/axolotl/utils/callbacks/mlflow_.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""MLFlow module for trainer callbacks"""
22

33
import logging
4+
import os
45
from shutil import copyfile
56
from tempfile import NamedTemporaryFile
67
from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@
1617
LOG = logging.getLogger("axolotl.callbacks")
1718

1819

20+
def should_log_artifacts() -> bool:
21+
truths = ["TRUE", "1", "YES"]
22+
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
23+
24+
1925
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
2026
# pylint: disable=duplicate-code
2127
"""Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ def on_train_begin(
3238
):
3339
if is_main_process():
3440
try:
35-
with NamedTemporaryFile(
36-
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
37-
) as temp_file:
38-
copyfile(self.axolotl_config_path, temp_file.name)
39-
mlflow.log_artifact(temp_file.name, artifact_path="")
41+
if should_log_artifacts():
42+
with NamedTemporaryFile(
43+
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
44+
) as temp_file:
45+
copyfile(self.axolotl_config_path, temp_file.name)
46+
mlflow.log_artifact(temp_file.name, artifact_path="")
47+
LOG.info(
48+
"The Axolotl config has been saved to the MLflow artifacts."
49+
)
50+
else:
4051
LOG.info(
41-
"The Axolotl config has been saved to the MLflow artifacts."
52+
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
4253
)
4354
except (FileNotFoundError, ConnectionError) as err:
4455
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

0 commit comments

Comments
 (0)