Skip to content

Commit

Permalink
[Train] MLflow start run under correct experiment (ray-project#23662)
Browse files Browse the repository at this point in the history
Start Mlflow run under correct mlflow experiment
  • Loading branch information
amogkam authored Apr 6, 2022
1 parent fdc6e02 commit 8becbfa
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
12 changes: 12 additions & 0 deletions python/ray/train/examples/mlflow_simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@ def train_func():
print("Run directory:", trainer.latest_run_dir)

trainer.shutdown()

# How to visualize the logs

# Navigate to the run directory of the trainer.
# For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001`
# $ cd <TRAINER_RUN_DIR>
#
# # View the MLflow UI.
# $ mlflow ui
#
# # View the tensorboard UI.
# $ tensorboard --logdir .
3 changes: 2 additions & 1 deletion python/ray/train/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def train_func(config):

client = MlflowClient(tracking_uri=callback.mlflow_util._mlflow.get_tracking_uri())

all_runs = callback.mlflow_util._mlflow.search_runs(experiment_ids=["0"])
experiment_id = client.get_experiment_by_name("test_exp").experiment_id
all_runs = callback.mlflow_util._mlflow.search_runs(experiment_ids=[experiment_id])
assert len(all_runs) == 1
# all_runs is a pandas dataframe.
all_runs = all_runs.to_dict(orient="records")
Expand Down
5 changes: 4 additions & 1 deletion python/ray/util/ml_utils/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def start_run(
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

client = self._get_client()
tags = tags or {}
tags[MLFLOW_RUN_NAME] = run_name
run = client.create_run(experiment_id=self.experiment_id, tags=tags)

Expand All @@ -211,7 +212,9 @@ def _start_active_run(
if active_run:
return active_run

return self._mlflow.start_run(run_name=run_name, tags=tags)
return self._mlflow.start_run(
run_name=run_name, experiment_id=self.experiment_id, tags=tags
)

def _run_exists(self, run_id: str) -> bool:
"""Check if run with the provided id exists."""
Expand Down
20 changes: 19 additions & 1 deletion python/ray/util/ml_utils/tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def test_experiment_name(self):
)
assert self.mlflow_util.experiment_id == "0"

def test_run_started_with_correct_experiment(self):
experiment_name = "my_experiment_name"
# Make sure run is started under the correct experiment.
self.mlflow_util.setup_mlflow(
tracking_uri=self.tracking_uri, experiment_name=experiment_name
)
run = self.mlflow_util.start_run(set_active=True)
assert (
run.info.experiment_id
== self.mlflow_util._mlflow.get_experiment_by_name(
experiment_name
).experiment_id
)

self.mlflow_util.end_run()

def test_experiment_name_env_var(self):
os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment"
self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri)
Expand Down Expand Up @@ -79,10 +95,12 @@ def test_log_params(self):
params2 = {"b": "b"}
self.mlflow_util.start_run(set_active=True)
self.mlflow_util.log_params(params_to_log=params2, run_id=run_id)
assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.params == {
run = self.mlflow_util._mlflow.get_run(run_id=run_id)
assert run.data.params == {
**params,
**params2,
}

self.mlflow_util.end_run()

def test_log_metrics(self):
Expand Down

0 comments on commit 8becbfa

Please sign in to comment.