Add is_in_databricks_runtime method (mlflow#4544)
* init
* update
* update
* update
* update
* update
* fix tests
* update
* fix format
* update
* update

Signed-off-by (all commits): Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 authored Jul 13, 2021
1 parent 48a5128 commit 92561c0
Showing 5 changed files with 35 additions and 32 deletions.
7 changes: 2 additions & 5 deletions mlflow/utils/autologging_utils/versioning.py
@@ -5,10 +5,7 @@
 from packaging.version import Version, InvalidVersion
 from pkg_resources import resource_filename
 
-from mlflow.utils.databricks_utils import (
-    is_in_databricks_job,
-    is_in_databricks_notebook,
-)
+from mlflow.utils.databricks_utils import is_in_databricks_runtime
 
 
 # A map FLAVOR_NAME -> a tuple of (dependent_module_name, key_in_module_version_info_dict)
@@ -72,7 +69,7 @@ def is_flavor_supported_for_associated_package_versions(flavor_name):
     actual_version = importlib.import_module(module_name).__version__
 
     # In Databricks, treat 'pyspark 3.x.y.dev0' as 'pyspark 3.x.y'
-    if module_name == "pyspark" and (is_in_databricks_notebook() or is_in_databricks_job()):
+    if module_name == "pyspark" and is_in_databricks_runtime():
         actual_version = _strip_dev_version_suffix(actual_version)
 
     if _violates_pep_440(actual_version) or _is_pre_or_dev_release(actual_version):
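The `_strip_dev_version_suffix` helper used above is not part of this diff. A minimal sketch of the normalization it performs, assuming it simply drops a trailing dev segment (this is an illustrative guess, not mlflow's actual implementation):

import re


def _strip_dev_version_suffix(version):
    # Hypothetical sketch: "3.1.2.dev0" -> "3.1.2"; non-dev versions
    # pass through unchanged.
    return re.sub(r"\.?dev.*$", "", version)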
12 changes: 11 additions & 1 deletion mlflow/utils/databricks_utils.py
@@ -88,6 +88,16 @@ def is_in_databricks_job():
         return False
 
 
+def is_in_databricks_runtime():
+    try:
+        # pylint: disable=unused-import,import-error,no-name-in-module,unused-variable
+        import pyspark.databricks
+
+        return True
+    except ModuleNotFoundError:
+        return False
+
+
 def is_dbfs_fuse_available():
     with open(os.devnull, "w") as devnull_stderr, open(os.devnull, "w") as devnull_stdout:
         try:
@@ -135,7 +145,7 @@ def get_notebook_path():
 
 
 def get_databricks_runtime():
-    if is_in_databricks_notebook() or is_in_databricks_job():
+    if is_in_databricks_runtime():
        spark_session = _get_active_spark_session()
         if spark_session is not None:
             return spark_session.conf.get(
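The new check hinges on the fact that the pyspark build bundled with the Databricks Runtime includes a `pyspark.databricks` package that open-source pyspark does not ship, so a successful import identifies the runtime regardless of whether the code runs in a notebook, a job, or neither. A minimal standalone sketch of the same idea (the function name here is illustrative, not mlflow API):

import importlib.util


def running_in_databricks_runtime():
    # pyspark.databricks ships only with the Databricks-bundled pyspark,
    # so finding its spec implies we are on a Databricks cluster.
    try:
        return importlib.util.find_spec("pyspark.databricks") is not None
    except ModuleNotFoundError:
        # pyspark itself is not installed.
        return False

Unlike the previous `is_in_databricks_notebook() or is_in_databricks_job()` condition, this covers any code executing on a Databricks cluster, not just the notebook and job entry points.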
13 changes: 3 additions & 10 deletions tests/autologging/test_autologging_utils.py
@@ -841,17 +841,10 @@ def test_dev_version_pyspark_is_supported_in_databricks(flavor, module_version,
     with mock.patch(module_name + ".__version__", module_version):
         # In Databricks
         with mock.patch(
-            "mlflow.utils.autologging_utils.versioning.is_in_databricks_notebook",
-            return_value=True,
-        ) as mock_notebook:
+            "mlflow.utils.autologging_utils.versioning.is_in_databricks_runtime", return_value=True,
+        ) as mock_runtime:
             assert is_flavor_supported_for_associated_package_versions(flavor) == expected_result
-            mock_notebook.assert_called()
-
-        with mock.patch(
-            "mlflow.utils.autologging_utils.versioning.is_in_databricks_job", return_value=True,
-        ) as mock_job:
-            assert is_flavor_supported_for_associated_package_versions(flavor) == expected_result
-            mock_job.assert_called()
+            mock_runtime.assert_called()
 
         # Not in Databricks
         assert is_flavor_supported_for_associated_package_versions(flavor) is False
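Note the patch target: the test patches `is_in_databricks_runtime` where `versioning.py` imported it, not where `databricks_utils.py` defines it, because `from x import y` binds the function into the importing module's namespace. A self-contained illustration of this standard `mock.patch` rule, using hypothetical module names:

import sys
import types
from unittest import mock

# Hypothetical stand-ins for databricks_utils.py (defining module) and
# versioning.py (importing module).
defining = types.ModuleType("defining")
defining.helper = lambda: False
sys.modules["defining"] = defining

consumer = types.ModuleType("consumer")
exec("from defining import helper\ndef run():\n    return helper()", consumer.__dict__)
sys.modules["consumer"] = consumer

# Patching "defining.helper" would not affect consumer.run(); the name
# must be patched where it is looked up:
with mock.patch("consumer.helper", return_value=True):
    assert consumer.run() is True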
16 changes: 0 additions & 16 deletions tests/tracking/test_client.py
@@ -500,19 +500,3 @@ def test_get_databricks_runtime_nondb(mock_spark_session):
     runtime = get_databricks_runtime()
     assert runtime is None
     mock_spark_session.conf.get.assert_not_called()
-
-
-def test_get_databricks_runtime_in_notebook(mock_spark_session):
-    with mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True):
-        get_databricks_runtime()
-        mock_spark_session.conf.get.assert_called_once_with(
-            "spark.databricks.clusterUsageTags.sparkVersion", default=None
-        )
-
-
-def test_get_databricks_runtime_in_job(mock_spark_session):
-    with mock.patch("mlflow.utils.databricks_utils.is_in_databricks_job", return_value=True):
-        get_databricks_runtime()
-        mock_spark_session.conf.get.assert_called_once_with(
-            "spark.databricks.clusterUsageTags.sparkVersion", default=None
-        )
19 changes: 19 additions & 0 deletions tests/utils/test_databricks_utils.py
@@ -1,3 +1,4 @@
+import sys
 from unittest import mock
 import pytest
 
@@ -20,6 +21,7 @@ def test_no_throw():
     assert not databricks_utils.is_in_databricks_notebook()
     assert not databricks_utils.is_in_databricks_job()
     assert not databricks_utils.is_dbfs_fuse_available()
+    assert not databricks_utils.is_in_databricks_runtime()
 
 
 @mock.patch("databricks_cli.configure.provider.get_config")
@@ -200,3 +202,20 @@ def test_databricks_params_throws_errors(ProfileConfigProvider):
     ProfileConfigProvider.return_value = mock_provider
     with pytest.raises(Exception):
         databricks_utils.get_databricks_host_creds()
+
+
+def test_is_in_databricks_runtime():
+    with mock.patch(
+        "sys.modules",
+        new={**sys.modules, "pyspark": mock.MagicMock(), "pyspark.databricks": mock.MagicMock()},
+    ):
+        # pylint: disable=unused-import,import-error,no-name-in-module,unused-variable
+        import pyspark.databricks
+
+        assert databricks_utils.is_in_databricks_runtime()
+
+    with mock.patch("sys.modules", new={**sys.modules, "pyspark": mock.MagicMock()}):
+        with pytest.raises(ModuleNotFoundError, match="No module named 'pyspark.databricks'"):
+            # pylint: disable=unused-import,import-error,no-name-in-module,unused-variable
+            import pyspark.databricks
+        assert not databricks_utils.is_in_databricks_runtime()
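The test drives both branches by swapping `sys.modules`: because the import system consults `sys.modules` before any finders, planting a `MagicMock` under the key "pyspark.databricks" makes the import succeed, while supplying only "pyspark" leaves the submodule lookup to fail with `ModuleNotFoundError`. The same trick in isolation, against a deliberately fake module name:

import sys
from unittest import mock

# The import system checks sys.modules before any finders, so a planted
# entry satisfies the import statement without a real module on disk.
with mock.patch.dict(sys.modules, {"fake_databricks_probe": mock.MagicMock()}):
    import fake_databricks_probe

    assert isinstance(fake_databricks_probe, mock.MagicMock)

# patch.dict restores sys.modules on exit, so the fake entry is gone.
assert "fake_databricks_probe" not in sys.modules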
