Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch dbutils.notebook.entry_point... to return current local notebook path from env var #618

Merged
merged 9 commits into from
May 29, 2024
80 changes: 79 additions & 1 deletion databricks/sdk/dbutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import json
import logging
import os
import threading
import typing
from collections import namedtuple
from dataclasses import dataclass

from .core import ApiClient, Config, DatabricksError
from .mixins import compute as compute_ext
Expand Down Expand Up @@ -241,6 +243,75 @@ def __getattr__(self, util) -> '_ProxyUtil':
name=util)


@dataclass
class OverrideResult:
    """Wrapper around the value produced by a local proxy override.

    Wrapping the value lets callers distinguish "the override ran and returned
    ``None``" (an ``OverrideResult`` whose ``result`` is ``None``) from "no
    override was found or it was not able to run" (a bare ``None``).
    """

    # Value returned by the override function; may legitimately be None.
    result: typing.Any


def get_local_notebook_path() -> str:
    """Return the local notebook path taken from ``DATABRICKS_SOURCE_FILE``.

    Returns:
        The value of the ``DATABRICKS_SOURCE_FILE`` environment variable.

    Raises:
        ValueError: if the environment variable is not set.
    """
    value = os.getenv("DATABRICKS_SOURCE_FILE")
    if value is None:
        # The message names the exact variable that is read above, so users can
        # copy it verbatim when fixing their environment.
        raise ValueError(
            "DATABRICKS_SOURCE_FILE environment variable is not set. This is required to get the local notebook path."
        )

    return value


class _OverrideProxyUtil:
    """Local stand-in for a chain of dbutils attribute accesses and calls.

    Each instance represents one prefix (for example
    ``notebook.entry_point.getDbutils()``) of a path registered in
    ``proxy_override_paths``. Attribute access and zero-argument calls extend
    the prefix; once the prefix equals a registered path, the registered
    function is run locally and its return value is handed back instead of
    anything being sent to remote dbutils.
    """

    # These are the paths that we want to override and not send to remote dbutils. NOTE: for each of these paths, no
    # prefixes are sent to remote either. This could lead to unintentional breakage.
    # Our current proxy implementation (which sends everything to remote dbutils) uses
    # `{util}.{method}(*args, **kwargs)` ONLY. This means it is completely safe to override paths starting with
    # `{util}.{attribute}.<other_parts>`, since none of the prefixes are being proxied to remote dbutils currently.
    proxy_override_paths = {
        'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
        get_local_notebook_path,
    }

    @classmethod
    def new(cls, path: str):
        """Return a proxy for ``path`` if it leads to a registered override, else ``None``."""
        if cls.__get_matching_overrides(path):
            return cls(path)
        return None

    def __init__(self, name: str):
        self._name = name

    @staticmethod
    def __is_path_prefix(path: str, registered: str) -> bool:
        """Tell whether ``path`` is ``registered`` itself or a whole-segment prefix of it."""
        if not registered.startswith(path):
            return False
        remainder = registered[len(path):]
        # Require a segment boundary so that e.g. `notebook.entry` is not
        # treated as a prefix of `notebook.entry_point...`.
        return remainder == "" or remainder[0] in ".("

    @classmethod
    def __get_matching_overrides(cls, path: str) -> typing.List[str]:
        """List the registered override paths that ``path`` leads to."""
        return [registered for registered in cls.proxy_override_paths if cls.__is_path_prefix(path, registered)]

    def __run_override(self, path: str) -> typing.Optional[OverrideResult]:
        """Resolve ``path`` against the registered overrides.

        Returns:
            ``OverrideResult`` wrapping the override's return value when ``path``
            is a registered path; ``OverrideResult`` wrapping a new proxy when
            ``path`` is a prefix of a registered path; ``None`` otherwise.
        """
        override = self.proxy_override_paths.get(path)
        if override is not None:
            # Exact lookup: the override runs even if other registered paths
            # happen to extend this one.
            return OverrideResult(override())

        if self.__get_matching_overrides(path):
            return OverrideResult(_OverrideProxyUtil(name=path))

        return None

    def __call__(self, *args, **kwds) -> typing.Any:
        """Extend the path with a zero-argument call.

        Raises:
            TypeError: if any argument is passed, or if the called path does not
                lead to a registered override.
        """
        given = len(args) + len(kwds)
        if given:
            raise TypeError(f"{self._name}() takes no arguments ({given} given)")

        result = self.__run_override(f"{self._name}()")
        if result is not None:
            return result.result

        raise TypeError(f"{self._name} is not callable")

    def __getattr__(self, method: str) -> typing.Any:
        """Extend the path with an attribute access.

        Raises:
            AttributeError: if the extended path does not lead to a registered override.
        """
        if method == "_name":
            # `_name` is only missing on a half-initialised instance; without
            # this guard reading `self._name` below would recurse without bound.
            raise AttributeError(method)

        result = self.__run_override(f"{self._name}.{method}")
        if result is not None:
            return result.result

        raise AttributeError(f"module {self._name} has no attribute {method}")


class _ProxyUtil:
"""Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""

Expand All @@ -252,7 +323,14 @@ def __init__(self, *, command_execution: compute.CommandExecutionAPI,
self._context_factory = context_factory
self._name = name

def __getattr__(self, method: str) -> '_ProxyCall':
def __call__(self):
raise NotImplementedError(f"dbutils.{self._name} is not callable")

def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
override = _OverrideProxyUtil.new(f"{self._name}.{method}")
if override:
return override

return _ProxyCall(command_execution=self._commands,
cluster_id=self._cluster_id,
context_factory=self._context_factory,
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from databricks.sdk.core import Config
from databricks.sdk.credentials_provider import credentials_provider

from .integration.conftest import restorable_env # type: ignore


@credentials_provider('noop', [])
def noop_credentials(_: any):
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,15 @@ def _load_debug_env_if_runs_from_ide(key) -> bool:

def _is_in_debug() -> bool:
    """Tell whether the process was started by one of the known test-launcher scripts.

    Compares the file name of ``sys.argv[0]`` with the launcher script names
    below (presumably the JetBrains and VS Code test runners — confirm).
    """
    launcher_scripts = ['_jb_pytest_runner.py', 'testlauncher.py']
    entry_script = os.path.split(sys.argv[0])[1]
    return entry_script in launcher_scripts


@pytest.fixture(scope="function")
def restorable_env():
    """Snapshot ``os.environ`` before the test and restore it afterwards.

    After the test, variables it added are removed, variables it changed are
    reset, and variables it deleted are put back.
    """
    import os
    saved_env = os.environ.copy()
    yield
    # Iterate over a copy of the keys: entries are deleted inside the loop.
    for key in list(os.environ):
        if key not in saved_env:
            del os.environ[key]
    # Walk the snapshot (not the live environment) so that variables the test
    # removed are restored as well as those it modified.
    for key, value in saved_env.items():
        if os.environ.get(key) != value:
            os.environ[key] = value
12 changes: 0 additions & 12 deletions tests/integration/test_dbconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,6 @@ def reload_modules(name: str):
print(f"Failed to reload {name}: {e}")


@pytest.fixture(scope="function")
def restorable_env():
    """Snapshot ``os.environ`` before the test and restore it afterwards.

    After the test, variables it added are removed, variables it changed are
    reset, and variables it deleted are put back.
    """
    import os
    saved_env = os.environ.copy()
    yield
    # Iterate over a copy of the keys: entries are deleted inside the loop.
    for key in list(os.environ):
        if key not in saved_env:
            del os.environ[key]
    # Walk the snapshot (not the live environment) so that variables the test
    # removed are restored as well as those it modified.
    for key, value in saved_env.items():
        if os.environ.get(key) != value:
            os.environ[key] = value


@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
def setup_dbconnect_test(request, env_or_skip, restorable_env):
dbr = request.param
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,11 @@ def test_jobs_task_values_get_throws(dbutils):
except TypeError as e:
assert str(
e) == 'Must pass debugValue when calling get outside of a job context. debugValue cannot be None.'


def test_dbutils_proxy_overrides(dbutils, mocker, restorable_env):
    """The notebook-path chain on dbutils returns the value of DATABRICKS_SOURCE_FILE."""
    import os

    expected_path = "test_source_file"
    os.environ["DATABRICKS_SOURCE_FILE"] = expected_path
    mocker.patch('databricks.sdk.dbutils.RemoteDbUtils._cluster_id', return_value="test_cluster_id")

    notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    assert notebook_context.notebookPath().get() == expected_path
Loading