diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index f9486bc3e3d31..16ad804065d05 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -103,7 +103,13 @@ def start_run(run_id=None, experiment_id=None, run_name=None, nested=False): raise Exception(("Run with UUID {} is already active. To start a nested " + "run call start_run with nested=True").format( _active_run_stack[0].info.run_id)) - existing_run_id = run_id or os.environ.get(_RUN_ID_ENV_VAR, None) + if run_id: + existing_run_id = run_id + elif _RUN_ID_ENV_VAR in os.environ: + existing_run_id = os.environ[_RUN_ID_ENV_VAR] + del os.environ[_RUN_ID_ENV_VAR] + else: + existing_run_id = None if existing_run_id: _validate_run_id(existing_run_id) active_run_obj = MlflowClient().get_run(existing_run_id) diff --git a/tests/tracking/test_tracking.py b/tests/tracking/test_tracking.py index 221cf791bfe11..69f3f3fbf6082 100644 --- a/tests/tracking/test_tracking.py +++ b/tests/tracking/test_tracking.py @@ -20,6 +20,7 @@ from mlflow.utils.file_utils import local_file_uri_to_path from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_USER, MLFLOW_SOURCE_NAME, \ MLFLOW_SOURCE_TYPE +from mlflow.tracking.fluent import _RUN_ID_ENV_VAR from tests.projects.utils import tracking_uri_mock @@ -453,10 +454,16 @@ def test_with_startrun(): def test_parent_create_run(tracking_uri_mock): + + with mlflow.start_run() as parent_run: + parent_run_id = parent_run.info.run_id + os.environ[_RUN_ID_ENV_VAR] = parent_run_id with mlflow.start_run() as parent_run: + assert parent_run.info.run_id == parent_run_id with pytest.raises(Exception, match='To start a nested run'): mlflow.start_run() with mlflow.start_run(nested=True) as child_run: + assert child_run.info.run_id != parent_run_id with mlflow.start_run(nested=True) as grand_child_run: pass