Skip to content

Commit

Permalink
make fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikgupta-db committed Apr 17, 2024
1 parent 8009f94 commit a15b9d5
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 21 deletions.
43 changes: 24 additions & 19 deletions databricks/sdk/dbutils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import base64
from dataclasses import dataclass
import json
import logging
import os
import threading
import typing
from collections import namedtuple
from dataclasses import dataclass

from .core import ApiClient, Config, DatabricksError
from .mixins import compute as compute_ext
Expand Down Expand Up @@ -247,46 +247,50 @@ def __getattr__(self, util) -> '_ProxyUtil':
class OverrideResult:
result: typing.Any


def get_local_notebook_path():
value = os.getenv("DATABRICKS_SOURCE_FILE")
if value is None:
raise ValueError("DABRICKS_SOURCE_FILE environment variable is not set. This is required to get the local notebook path.")
raise ValueError(
"DABRICKS_SOURCE_FILE environment variable is not set. This is required to get the local notebook path."
)

return value


class _OverrideProxyUtil:

@classmethod
def new(cls, path: str):
if len(cls.__get_matching_overrides(path)) > 0:
return _OverrideProxyUtil(path)
return None

def __init__(self, name: str):
self._name = name


# These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
# are sent to remote either. This could lead to unintentional breakage.
# are sent to remote either. This could lead to unintentional breakage.
# Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
# This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
# are being proxied to remote dbutils currently.
# are being proxied to remote dbutils currently.
proxy_override_paths = {
'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()': get_local_notebook_path,
}
'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
get_local_notebook_path,
}

@classmethod
@classmethod
def __get_matching_overrides(cls, path: str) -> list[str]:
return list(filter(lambda x: x.startswith(path), list(cls.proxy_override_paths.keys())))

def __run_override(self, path: str) -> typing.Optional[OverrideResult]:
overrides = self.__get_matching_overrides(path)
if len(overrides) == 1 and overrides[0] == path:
return OverrideResult(self.proxy_override_paths[overrides[0]]())
if len(overrides) > 0:

if len(overrides) > 0:
return OverrideResult(_OverrideProxyUtil(name=path))

return None

def __call__(self, *args, **kwds) -> typing.Any:
Expand All @@ -297,22 +301,23 @@ def __call__(self, *args, **kwds) -> typing.Any:
result = self.__run_override(callable_path)
if result:
return result.result

raise TypeError(f"{self._name} is not callable")

def __getattr__(self, method: str) -> typing.Any:
result = self.__run_override(f"{self._name}.{method}")
if result:
return result.result

raise AttributeError(f"module {self._name} has no attribute {method}")



class _ProxyUtil:
"""Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""


def __init__(self, *, command_execution: compute.CommandExecutionAPI,
context_factory: typing.Callable[[], compute.ContextStatusResponse], cluster_id: str, name: str):
context_factory: typing.Callable[[],
compute.ContextStatusResponse], cluster_id: str, name: str):
self._commands = command_execution
self._cluster_id = cluster_id
self._context_factory = context_factory
Expand All @@ -325,7 +330,7 @@ def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyU
override = _OverrideProxyUtil.new(f"{self._name}.{method}")
if override:
return override

return _ProxyCall(command_execution=self._commands,
cluster_id=self._cluster_id,
context_factory=self._context_factory,
Expand Down
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from databricks.sdk import WorkspaceClient


client = WorkspaceClient(profile='logfood', cluster_id="0416-112901-opjiy4x6")
dbutils = client.dbutils

print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get())
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from databricks.sdk.core import Config
from databricks.sdk.credentials_provider import credentials_provider
from .integration.conftest import restorable_env


@credentials_provider('noop', [])
def noop_credentials(_: any):
Expand Down
1 change: 1 addition & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _load_debug_env_if_runs_from_ide(key) -> bool:
def _is_in_debug() -> bool:
return os.path.basename(sys.argv[0]) in ['_jb_pytest_runner.py', 'testlauncher.py', ]


@pytest.fixture(scope="function")
def restorable_env():
import os
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_dbconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def reload_modules(name: str):
except Exception as e:
print(f"Failed to reload {name}: {e}")


@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
def setup_dbconnect_test(request, env_or_skip, restorable_env):
dbr = request.param
Expand Down
4 changes: 3 additions & 1 deletion tests/test_dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,10 @@ def test_jobs_task_values_get_throws(dbutils):
assert str(
e) == 'Must pass debugValue when calling get outside of a job context. debugValue cannot be None.'


def test_dbutils_proxy_overrides(dbutils, mocker, restorable_env):
import os
os.environ["DATABRICKS_SOURCE_FILE"] = "test_source_file"
mocker.patch('databricks.sdk.dbutils.RemoteDbUtils._cluster_id', return_value="test_cluster_id")
assert dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get() == "test_source_file"
assert dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get(
) == "test_source_file"

0 comments on commit a15b9d5

Please sign in to comment.