Skip to content

Commit

Permalink
Feature: Option to set the tracking URI for MLflowCallback. (#29032)
Browse files Browse the repository at this point in the history
* Added option to set tracking URI for MLflowCallback.

* Added option to set tracking URI for MLflowCallback.

* Changed  to  in docstring.
  • Loading branch information
seanswyi authored Feb 16, 2024
1 parent be42c24 commit 161fe42
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,9 @@ def setup(self, args, state, model):
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
[`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
storage will just copy the files to your artifact location.
- **MLFLOW_TRACKING_URI** (`str`, *optional*, defaults to `""`):
Whether to store runs at a specific path or remote server. Default to an empty string which will store runs
at `./mlruns` locally.
- **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`):
Whether to use an MLflow experiment_name under which to launch the run. Default to `None` which will point
to the `Default` experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be
Expand All @@ -978,14 +981,22 @@ def setup(self, args, state, model):
"""
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
logger.debug(
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
f" tags={self._nested_run}"
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
)
if state.is_world_process_zero:
self._ml_flow.set_tracking_uri(self._tracking_uri)

if self._tracking_uri == "":
logger.debug(f"MLflow tracking URI is not set. Runs will be stored at {os.path.realpath('./mlruns')}")
else:
logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}")

if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
if self._experiment_name:
# Use of set_experiment() ensure that Experiment is created if not exists
Expand Down

0 comments on commit 161fe42

Please sign in to comment.