From a15b9d56dbcb3c03df234066e45c45ee90060a1d Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Wed, 17 Apr 2024 15:53:15 +0200 Subject: [PATCH] make fmt --- databricks/sdk/dbutils.py | 43 ++++++++++++++++------------- main.py | 7 +++++ tests/conftest.py | 2 +- tests/integration/conftest.py | 1 + tests/integration/test_dbconnect.py | 1 + tests/test_dbutils.py | 4 ++- 6 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 main.py diff --git a/databricks/sdk/dbutils.py b/databricks/sdk/dbutils.py index e901b452..84fdef7b 100644 --- a/databricks/sdk/dbutils.py +++ b/databricks/sdk/dbutils.py @@ -1,11 +1,11 @@ import base64 -from dataclasses import dataclass import json import logging import os import threading import typing from collections import namedtuple +from dataclasses import dataclass from .core import ApiClient, Config, DatabricksError from .mixins import compute as compute_ext @@ -247,13 +247,17 @@ def __getattr__(self, util) -> '_ProxyUtil': class OverrideResult: result: typing.Any + def get_local_notebook_path(): value = os.getenv("DATABRICKS_SOURCE_FILE") if value is None: - raise ValueError("DABRICKS_SOURCE_FILE environment variable is not set. This is required to get the local notebook path.") + raise ValueError( + "DABRICKS_SOURCE_FILE environment variable is not set. This is required to get the local notebook path." + ) return value + class _OverrideProxyUtil: @classmethod @@ -261,32 +265,32 @@ def new(cls, path: str): if len(cls.__get_matching_overrides(path)) > 0: return _OverrideProxyUtil(path) return None - + def __init__(self, name: str): self._name = name - # These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes - # are sent to remote either. This could lead to unintentional breakage. + # are sent to remote either. This could lead to unintentional breakage. # Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY. # This means, it is completely safe to override paths starting with `{util}.{attribute}.`, since none of the prefixes - # are being proxied to remote dbutils currently. + # are being proxied to remote dbutils currently. proxy_override_paths = { - 'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()': get_local_notebook_path, - } + 'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()': + get_local_notebook_path, + } - @classmethod + @classmethod def __get_matching_overrides(cls, path: str) -> list[str]: return list(filter(lambda x: x.startswith(path), list(cls.proxy_override_paths.keys()))) - + def __run_override(self, path: str) -> typing.Optional[OverrideResult]: overrides = self.__get_matching_overrides(path) if len(overrides) == 1 and overrides[0] == path: return OverrideResult(self.proxy_override_paths[overrides[0]]()) - - if len(overrides) > 0: + + if len(overrides) > 0: return OverrideResult(_OverrideProxyUtil(name=path)) - + return None def __call__(self, *args, **kwds) -> typing.Any: @@ -297,22 +301,23 @@ def __call__(self, *args, **kwds) -> typing.Any: result = self.__run_override(callable_path) if result: return result.result - + raise TypeError(f"{self._name} is not callable") - + def __getattr__(self, method: str) -> typing.Any: result = self.__run_override(f"{self._name}.{method}") if result: return result.result raise AttributeError(f"module {self._name} has no attribute {method}") - + + class _ProxyUtil: """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them""" - def __init__(self, *, command_execution: compute.CommandExecutionAPI, - context_factory: typing.Callable[[], compute.ContextStatusResponse], cluster_id: str, name: str): + context_factory: typing.Callable[[], + compute.ContextStatusResponse], cluster_id: str, name: str): self._commands = command_execution self._cluster_id = cluster_id self._context_factory = context_factory @@ -325,7 +330,7 @@ def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyU override = _OverrideProxyUtil.new(f"{self._name}.{method}") if override: return override - + return _ProxyCall(command_execution=self._commands, cluster_id=self._cluster_id, context_factory=self._context_factory, diff --git a/main.py b/main.py new file mode 100644 index 00000000..9727940f --- /dev/null +++ b/main.py @@ -0,0 +1,7 @@ +from databricks.sdk import WorkspaceClient + + +client = WorkspaceClient(profile='logfood', cluster_id="0416-112901-opjiy4x6") +dbutils = client.dbutils + +print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()) diff --git a/tests/conftest.py b/tests/conftest.py index ceaacae1..80753ae9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from databricks.sdk.core import Config from databricks.sdk.credentials_provider import credentials_provider -from .integration.conftest import restorable_env + @credentials_provider('noop', []) def noop_credentials(_: any): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a22d4be7..081b50a8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -105,6 +105,7 @@ def _load_debug_env_if_runs_from_ide(key) -> bool: def _is_in_debug() -> bool: return os.path.basename(sys.argv[0]) in ['_jb_pytest_runner.py', 'testlauncher.py', ] + @pytest.fixture(scope="function") def restorable_env(): import os diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 3d231c8a..59d327ae 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -22,6 +22,7 @@ def reload_modules(name: str): except Exception as e: print(f"Failed to reload {name}: {e}") + @pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys())) def setup_dbconnect_test(request, env_or_skip, restorable_env): dbr = request.param diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index b6e4ee2e..c25e07e3 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -227,8 +227,10 @@ def test_jobs_task_values_get_throws(dbutils): assert str( e) == 'Must pass debugValue when calling get outside of a job context. debugValue cannot be None.' + def test_dbutils_proxy_overrides(dbutils, mocker, restorable_env): import os os.environ["DATABRICKS_SOURCE_FILE"] = "test_source_file" mocker.patch('databricks.sdk.dbutils.RemoteDbUtils._cluster_id', return_value="test_cluster_id") - assert dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get() == "test_source_file" + assert dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get( + ) == "test_source_file"