diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 7e433be7f1abb4..3af00c98eb66b2 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -959,6 +959,9 @@ def setup(self, args, state, model): remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote storage will just copy the files to your artifact location. + - **MLFLOW_TRACKING_URI** (`str`, *optional*, defaults to `""`): + The path or remote server URI at which runs are stored. Defaults to an empty string, which stores runs + at `./mlruns` locally. - **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`): Whether to use an MLflow experiment_name under which to launch the run. Default to `None` which will point to the `Default` experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be @@ -978,14 +981,22 @@ def setup(self, args, state, model): """ self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES + self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "") self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES self._run_id = os.getenv("MLFLOW_RUN_ID", None) logger.debug( f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}," - f" tags={self._nested_run}" + f" tags={self._nested_run}, tracking_uri={self._tracking_uri}" ) if state.is_world_process_zero: + self._ml_flow.set_tracking_uri(self._tracking_uri) + + if self._tracking_uri == "": + logger.debug(f"MLflow tracking URI is not set. 
Runs will be stored at {os.path.realpath('./mlruns')}") + else: + logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}") + if self._ml_flow.active_run() is None or self._nested_run or self._run_id: if self._experiment_name: # Use of set_experiment() ensure that Experiment is created if not exists