Skip to content

Commit

Permalink
Created command context provider to set job group ID as part of run t…
Browse files Browse the repository at this point in the history
…ags (mlflow#4521)
  • Loading branch information
jerrylian-db authored Jul 2, 2021
1 parent e808a38 commit bc86f1d
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 0 deletions.
15 changes: 15 additions & 0 deletions mlflow/tracking/context/databricks_command_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from mlflow.tracking.context.abstract_context import RunContextProvider
from mlflow.utils import databricks_utils
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID


class DatabricksCommandRunContext(RunContextProvider):
def in_context(self):
return databricks_utils.get_job_group_id() is not None

def tags(self):
job_group_id = databricks_utils.get_job_group_id()
tags = {}
if job_group_id is not None:
tags[MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID] = job_group_id
return tags
2 changes: 2 additions & 0 deletions mlflow/tracking/context/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mlflow.tracking.context.databricks_notebook_context import DatabricksNotebookRunContext
from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext
from mlflow.tracking.context.databricks_cluster_context import DatabricksClusterRunContext
from mlflow.tracking.context.databricks_command_context import DatabricksCommandRunContext


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,6 +54,7 @@ def __iter__(self):
_run_context_provider_registry.register(DatabricksNotebookRunContext)
_run_context_provider_registry.register(DatabricksJobRunContext)
_run_context_provider_registry.register(DatabricksClusterRunContext)
_run_context_provider_registry.register(DatabricksCommandRunContext)

_run_context_provider_registry.register_entrypoints()

Expand Down
10 changes: 10 additions & 0 deletions mlflow/utils/databricks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def get_cluster_id():
return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")


def get_job_group_id():
try:
dbutils = _get_dbutils()
job_group_id = dbutils.entry_point.getJobGroupId()
if job_group_id is not None:
return job_group_id
except Exception:
return None


def get_job_id():
try:
return _get_command_context().jobId().get()
Expand Down
2 changes: 2 additions & 0 deletions mlflow/utils/mlflow_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
MLFLOW_DATABRICKS_WEBAPP_URL = "mlflow.databricks.webappURL"
MLFLOW_DATABRICKS_RUN_URL = "mlflow.databricks.runURL"
MLFLOW_DATABRICKS_CLUSTER_ID = "mlflow.databricks.cluster.id"
# The unique ID of a command execution in a Databricks notebook
MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID = "mlflow.databricks.notebook.commandID"
# The SHELL_JOB_ID and SHELL_JOB_RUN_ID tags are used for tracking the
# Databricks Job ID and Databricks Job Run ID associated with an MLflow Project run
MLFLOW_DATABRICKS_SHELL_JOB_ID = "mlflow.databricks.shellJobID"
Expand Down
21 changes: 21 additions & 0 deletions tests/tracking/context/test_databricks_command_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest import mock

from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID
from mlflow.tracking.context.databricks_command_context import DatabricksCommandRunContext


def test_databricks_command_run_context_in_context():
with mock.patch("mlflow.utils.databricks_utils.get_job_group_id", return_value="1"):
assert DatabricksCommandRunContext().in_context()


def test_databricks_command_run_context_tags():
with mock.patch("mlflow.utils.databricks_utils.get_job_group_id") as job_group_id_mock:
assert DatabricksCommandRunContext().tags() == {
MLFLOW_DATABRICKS_NOTEBOOK_COMMAND_ID: job_group_id_mock.return_value
}


def test_databricks_command_run_context_tags_nones():
with mock.patch("mlflow.utils.databricks_utils.get_job_group_id", return_value=None):
assert DatabricksCommandRunContext().tags() == {}

0 comments on commit bc86f1d

Please sign in to comment.