1
1
"""MLFlow module for trainer callbacks"""
2
2
3
3
import logging
4
+ import os
4
5
from shutil import copyfile
5
6
from tempfile import NamedTemporaryFile
6
7
from typing import TYPE_CHECKING
16
17
LOG = logging .getLogger ("axolotl.callbacks" )
17
18
18
19
20
+ def should_log_artifacts () -> bool :
21
+ truths = ["TRUE" , "1" , "YES" ]
22
+ return os .getenv ("HF_MLFLOW_LOG_ARTIFACTS" , "FALSE" ).upper () in truths
23
+
24
+
19
25
class SaveAxolotlConfigtoMlflowCallback (TrainerCallback ):
20
26
# pylint: disable=duplicate-code
21
27
"""Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ def on_train_begin(
32
38
):
33
39
if is_main_process ():
34
40
try :
35
- with NamedTemporaryFile (
36
- mode = "w" , delete = False , suffix = ".yml" , prefix = "axolotl_config_"
37
- ) as temp_file :
38
- copyfile (self .axolotl_config_path , temp_file .name )
39
- mlflow .log_artifact (temp_file .name , artifact_path = "" )
41
+ if should_log_artifacts ():
42
+ with NamedTemporaryFile (
43
+ mode = "w" , delete = False , suffix = ".yml" , prefix = "axolotl_config_"
44
+ ) as temp_file :
45
+ copyfile (self .axolotl_config_path , temp_file .name )
46
+ mlflow .log_artifact (temp_file .name , artifact_path = "" )
47
+ LOG .info (
48
+ "The Axolotl config has been saved to the MLflow artifacts."
49
+ )
50
+ else :
40
51
LOG .info (
41
- "The Axolotl config has been saved to the MLflow artifacts. "
52
+ "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false) "
42
53
)
43
54
except (FileNotFoundError , ConnectionError ) as err :
44
55
LOG .warning (f"Error while saving Axolotl config to MLflow: { err } " )
0 commit comments