diff --git a/.codegen.json b/.codegen.json index f8e3a2351..32d3979c0 100644 --- a/.codegen.json +++ b/.codegen.json @@ -8,5 +8,23 @@ }, "samples": { ".codegen/example.py.tmpl": "examples/{{.Service.SnakeName}}/{{.Method.SnakeName}}_{{.SnakeName}}.py" + }, + "version": { + "databricks/sdk/version.py": "__version__ = '$VERSION'" + }, + "toolchain": { + "required": ["python3"], + "pre_setup": [ + "python3 -m venv .databricks" + ], + "prepend_path": ".databricks/bin", + "setup": [ + "pip install '.[dev]'" + ], + "post_generate": [ + "pytest -m 'not integration' --cov=databricks --cov-report html tests", + "pip install .", + "python docs/gen-client-docs.py" + ] } -} \ No newline at end of file +} diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl index c52b8ed65..a8e730af0 100644 --- a/.codegen/__init__.py.tmpl +++ b/.codegen/__init__.py.tmpl @@ -17,6 +17,22 @@ from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}} {{- getOrDefault $mixins $genApi $genApi -}} {{- end -}} +def _make_dbutils(config: client.Config): + # We try to directly check if we are in runtime, instead of + # trying to import from databricks.sdk.runtime. This is to prevent + # remote dbutils from being created without the config, which is both + # expensive (will need to check all credential providers) and can + # throw errors (when no env vars are set). + try: + from dbruntime import UserNamespaceInitializer + except ImportError: + return dbutils.RemoteDbUtils(config) + + # We are in runtime, so we can use the runtime dbutils + from databricks.sdk.runtime import dbutils as runtime_dbutils + return runtime_dbutils + + class WorkspaceClient: def __init__(self, *{{range $args}}, {{.}}: str = None{{end}}, debug_truncate_bytes: int = None, @@ -33,7 +49,7 @@ class WorkspaceClient: product=product, product_version=product_version) self.config = config.copy() - self.dbutils = dbutils.RemoteDbUtils(self.config) + self.dbutils = _make_dbutils(self.config) self.api_client = client.ApiClient(self.config) self.files = FilesMixin(self.api_client) {{- range .Services}}{{if not .IsAccounts}} diff --git a/.codegen/changelog.md.tmpl b/.codegen/changelog.md.tmpl new file mode 100644 index 000000000..8ecf6b9b5 --- /dev/null +++ b/.codegen/changelog.md.tmpl @@ -0,0 +1,55 @@ +# Version changelog + +## {{.Version}} + +{{range .Changes -}} + * {{.}}. +{{end}}{{- if .ApiChanges}} +API Changes: +{{range .ApiChanges}} + * {{.Action}} {{template "what" .}}{{if .Extra}} {{.Extra}}{{with .Other}} {{template "what" .}}{{end}}{{end}}. +{{- end}} + +OpenAPI SHA: {{.Sha}}, Date: {{.Changed}} +{{- end}}{{if .DependencyUpdates}} +Dependency updates: +{{range .DependencyUpdates}} + * {{.}}. 
+{{- end -}} +{{end}} + +## {{.PrevVersion}} + +{{- define "what" -}} + {{if eq .X "package" -}} + `databricks.sdk.service.{{.Package.Name}}` package + {{- else if eq .X "service" -}} + {{template "service" .Service}} + {{- else if eq .X "method" -}} + `{{.Method.SnakeName}}()` method for {{template "service" .Method.Service}} + {{- else if eq .X "entity" -}} + {{template "entity" .Entity}} dataclass + {{- else if eq .X "field" -}} + `{{.Field.SnakeName}}` field for {{template "entity" .Field.Of}} + {{- end}} +{{- end -}} + +{{- define "service" -}} + [{{if .IsAccounts}}a{{else}}w{{end}}.{{.SnakeName}}](https://databricks-sdk-py.readthedocs.io/en/latest/{{if .IsAccounts}}account{{else}}workspace{{end}}/{{.SnakeName}}.html) {{if .IsAccounts}}account{{else}}workspace{{end}}-level service +{{- end -}} + +{{- define "entity" -}} + {{- if not . }}any /* ERROR */ + {{- else if .IsEmpty}}`any` + {{- else if .PascalName}}`databricks.sdk.service.{{.Package.Name}}.{{.PascalName}}` + {{- else if .IsAny}}`any` + {{- else if .IsString}}`str` + {{- else if .IsBool}}`bool` + {{- else if .IsInt64}}`int` + {{- else if .IsFloat64}}`float` + {{- else if .IsInt}}`int` + {{- else if .ArrayValue }}list[{{template "entity" .ArrayValue}}] + {{- else if .MapValue }}dict[str,{{template "entity" .MapValue}}] + {{- else}}`databricks.sdk.service.{{.Package.Name}}.{{.PascalName}}` + {{- end -}} +{{- end -}} diff --git a/.codegen/example.py.tmpl b/.codegen/example.py.tmpl index ea85e2f71..dba71d9bf 100644 --- a/.codegen/example.py.tmpl +++ b/.codegen/example.py.tmpl @@ -43,7 +43,7 @@ import time, base64, os {{- else if eq .Type "lookup" -}} {{template "expr" .X}}.{{.Field.SnakeName}} {{- else if eq .Type "enum" -}} - {{.Package}}.{{.Entity.PascalName}}.{{.Content}}{{if eq .Content "None"}}_{{end}} + {{.Package}}.{{.Entity.PascalName}}.{{.ConstantName}} {{- else if eq .Type "variable" -}} {{if eq .SnakeName "true"}}True {{- else if eq .SnakeName "false"}}False @@ -109,4 +109,4 @@ f'/Users/{w.current_user.me().user_name}/sdk-{time.time_ns()}' {{- else -}} {{.SnakeName}}({{range $i, $x := .Args}}{{if $i}}, {{end}}{{template "expr" .}}{{end}}) {{- end -}} -{{- end}} \ No newline at end of file +{{- end}} diff --git a/.codegen/service.py.tmpl b/.codegen/service.py.tmpl index 785c29440..590be6732 100644 --- a/.codegen/service.py.tmpl +++ b/.codegen/service.py.tmpl @@ -40,7 +40,7 @@ class {{.PascalName}}{{if eq "List" .PascalName}}Request{{end}}:{{if .Descriptio {{else if .Enum}}class {{.PascalName}}(Enum): {{if .Description}}"""{{.Comment " " 100 | trimSuffix "\"" }}"""{{end}} {{range .Enum }} - {{.Content}}{{if eq .Content "None"}}_{{end}} = '{{.Content}}'{{end}}{{end}} + {{.ConstantName}} = '{{.Content}}'{{end}}{{end}} {{end}} {{- define "from_dict_type" -}} {{- if not .Entity }}None @@ -113,8 +113,8 @@ class {{.Name}}API:{{if .Description}} def {{.SnakeName}}(self{{range .Binding}}, {{.PollField.SnakeName}}: {{template "type-nq" .PollField.Entity}}{{end}}, timeout=timedelta(minutes={{.Timeout}}), callback: Optional[Callable[[{{.Poll.Response.PascalName}}], None]] = None) -> {{.Poll.Response.PascalName}}: deadline = time.time() + timeout.total_seconds() - target_states = ({{range .Success}}{{.Entity.PascalName}}.{{.Content}}, {{end}}){{if .Failure}} - failure_states = ({{range .Failure}}{{.Entity.PascalName}}.{{.Content}}, {{end}}){{end}} + target_states = ({{range .Success}}{{.Entity.PascalName}}.{{.ConstantName}}, {{end}}){{if .Failure}} + failure_states = ({{range .Failure}}{{.Entity.PascalName}}.{{.ConstantName}}, 
{{end}}){{end}} status_message = 'polling...' attempt = 1 while time.time() < deadline: @@ -218,7 +218,7 @@ class {{.Name}}API:{{if .Description}} {{define "method-call-paginated" -}} {{if .Pagination.MultiRequest}} - {{if .Pagination.NeedsOffsetDedupe -}} + {{if .NeedsOffsetDedupe -}} # deduplicate items that may have been added during iteration seen = set() {{- end}}{{if and .Pagination.Offset (not (eq .Path "/api/2.0/clusters/events")) }} @@ -228,8 +228,8 @@ class {{.Name}}API:{{if .Description}} if '{{.Pagination.Results.Name}}' not in json or not json['{{.Pagination.Results.Name}}']: return for v in json['{{.Pagination.Results.Name}}']: - {{if .Pagination.NeedsOffsetDedupe -}} - i = v['{{.Pagination.Entity.IdentifierField.Name}}'] + {{if .NeedsOffsetDedupe -}} + i = v['{{.IdentifierField.Name}}'] if i in seen: continue seen.add(i) diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index bbf1544b4..c1c3d189c 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -48,7 +48,7 @@ IpAccessListsAPI, TokenManagementAPI, TokensAPI, WorkspaceConfAPI) -from databricks.sdk.service.sharing import (ProvidersAPI, +from databricks.sdk.service.sharing import (CleanRoomsAPI, ProvidersAPI, RecipientActivationAPI, RecipientsAPI, SharesAPI) from databricks.sdk.service.sql import (AlertsAPI, DashboardsAPI, @@ -59,6 +59,22 @@ SecretsAPI, WorkspaceAPI) +def _make_dbutils(config: client.Config): + # We try to directly check if we are in runtime, instead of + # trying to import from databricks.sdk.runtime. This is to prevent + # remote dbutils from being created without the config, which is both + # expensive (will need to check all credential providers) and can + # throw errors (when no env vars are set). + try: + from dbruntime import UserNamespaceInitializer + except ImportError: + return dbutils.RemoteDbUtils(config) + + # We are in runtime, so we can use the runtime dbutils + from databricks.sdk.runtime import dbutils as runtime_dbutils + return runtime_dbutils + + class WorkspaceClient: def __init__(self, @@ -108,12 +124,13 @@ def __init__(self, product=product, product_version=product_version) self.config = config.copy() - self.dbutils = dbutils.RemoteDbUtils(self.config) + self.dbutils = _make_dbutils(self.config) self.api_client = client.ApiClient(self.config) self.files = FilesMixin(self.api_client) self.account_access_control_proxy = AccountAccessControlProxyAPI(self.api_client) self.alerts = AlertsAPI(self.api_client) self.catalogs = CatalogsAPI(self.api_client) + self.clean_rooms = CleanRoomsAPI(self.api_client) self.cluster_policies = ClusterPoliciesAPI(self.api_client) self.clusters = ClustersExt(self.api_client) self.command_execution = CommandExecutionAPI(self.api_client) diff --git a/databricks/sdk/_widgets/__init__.py b/databricks/sdk/_widgets/__init__.py new file mode 100644 index 000000000..4fef42696 --- /dev/null +++ b/databricks/sdk/_widgets/__init__.py @@ -0,0 +1,72 @@ +import logging +import typing +import warnings +from abc import ABC, abstractmethod + + +class WidgetUtils(ABC): + + def get(self, name: str): + return self._get(name) + + @abstractmethod + def _get(self, name: str) -> str: + pass + + def getArgument(self, name: str, default_value: typing.Optional[str] = None): + try: + return self.get(name) + except Exception: + return default_value + + def remove(self, name: str): + self._remove(name) + + @abstractmethod + def _remove(self, name: str): + pass + + def removeAll(self): + self._remove_all() + + @abstractmethod + def 
_remove_all(self): + pass + + +try: + # We only use ipywidgets if we are in a notebook interactive shell otherwise we raise error, + # to fallback to using default_widgets. Also, users WILL have IPython in their notebooks (jupyter), + # because we DO NOT SUPPORT any other notebook backends, and hence fallback to default_widgets. + from IPython.core.getipython import get_ipython + + # Detect if we are in an interactive notebook by iterating over the mro of the current ipython instance, + # to find ZMQInteractiveShell (jupyter). When used from REPL or file, this check will fail, since the + # mro only contains TerminalInteractiveShell. + if len(list(filter(lambda i: i.__name__ == 'ZMQInteractiveShell', get_ipython().__class__.__mro__))) == 0: + logging.debug("Not in an interactive notebook. Skipping ipywidgets implementation for dbutils.") + raise EnvironmentError("Not in an interactive notebook.") + + # For import errors in IPyWidgetUtil, we provide a warning message, prompting users to install the + # correct installation group of the sdk. + try: + from .ipywidgets_utils import IPyWidgetUtil + + widget_impl = IPyWidgetUtil + logging.debug("Using ipywidgets implementation for dbutils.") + + except ImportError as e: + # Since we are certain that we are in an interactive notebook, we can make assumptions about + # formatting and make the warning nicer for the user. + warnings.warn( + "\nTo use databricks widgets interactively in your notebook, please install databricks sdk using:\n" + "\tpip install 'databricks-sdk[notebook]'\n" + "Falling back to default_value_only implementation for databricks widgets.") + logging.debug(f"{e.msg}. Skipping ipywidgets implementation for dbutils.") + raise e + +except: + from .default_widgets_utils import DefaultValueOnlyWidgetUtils + + widget_impl = DefaultValueOnlyWidgetUtils + logging.debug("Using default_value_only implementation for dbutils.") diff --git a/databricks/sdk/_widgets/default_widgets_utils.py b/databricks/sdk/_widgets/default_widgets_utils.py new file mode 100644 index 000000000..9b61a75f6 --- /dev/null +++ b/databricks/sdk/_widgets/default_widgets_utils.py @@ -0,0 +1,42 @@ +import typing + +from . 
import WidgetUtils + + +class DefaultValueOnlyWidgetUtils(WidgetUtils): + + def __init__(self) -> None: + self._widgets: typing.Dict[str, str] = {} + + def text(self, name: str, defaultValue: str, label: typing.Optional[str] = None): + self._widgets[name] = defaultValue + + def dropdown(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._widgets[name] = defaultValue + + def combobox(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._widgets[name] = defaultValue + + def multiselect(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._widgets[name] = defaultValue + + def _get(self, name: str) -> str: + return self._widgets[name] + + def _remove(self, name: str): + del self._widgets[name] + + def _remove_all(self): + self._widgets = {} diff --git a/databricks/sdk/_widgets/ipywidgets_utils.py b/databricks/sdk/_widgets/ipywidgets_utils.py new file mode 100644 index 000000000..6f27df438 --- /dev/null +++ b/databricks/sdk/_widgets/ipywidgets_utils.py @@ -0,0 +1,87 @@ +import typing + +from IPython.core.display_functions import display +from ipywidgets.widgets import (ValueWidget, Widget, widget_box, + widget_selection, widget_string) + +from .default_widgets_utils import WidgetUtils + + +class DbUtilsWidget: + + def __init__(self, label: str, value_widget: ValueWidget) -> None: + self.label_widget = widget_string.Label(label) + self.value_widget = value_widget + self.box = widget_box.Box([self.label_widget, self.value_widget]) + + def display(self): + display(self.box) + + def close(self): + self.label_widget.close() + self.value_widget.close() + self.box.close() + + @property + def value(self): + value = self.value_widget.value + if type(value) == str or value is None: + return value + if type(value) == list or type(value) == tuple: + return ','.join(value) + + raise ValueError("The returned value has invalid type (" + type(value) + ").") + + +class IPyWidgetUtil(WidgetUtils): + + def __init__(self) -> None: + self._widgets: typing.Dict[str, DbUtilsWidget] = {} + + def _register(self, name: str, widget: ValueWidget, label: typing.Optional[str] = None): + label = label if label is not None else name + w = DbUtilsWidget(label, widget) + + if name in self._widgets: + self.remove(name) + + self._widgets[name] = w + w.display() + + def text(self, name: str, defaultValue: str, label: typing.Optional[str] = None): + self._register(name, widget_string.Text(defaultValue), label) + + def dropdown(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._register(name, widget_selection.Dropdown(value=defaultValue, options=choices), label) + + def combobox(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._register(name, widget_string.Combobox(value=defaultValue, options=choices), label) + + def multiselect(self, + name: str, + defaultValue: str, + choices: typing.List[str], + label: typing.Optional[str] = None): + self._register( + name, + widget_selection.SelectMultiple(value=(defaultValue, ), + options=[("__EMPTY__", ""), *list(zip(choices, choices))]), label) + + def _get(self, name: str) -> str: + return self._widgets[name].value + + def _remove(self, name: str): + self._widgets[name].close() + del self._widgets[name] + + def _remove_all(self): + Widget.close_all() + self._widgets = {} diff 
--git a/databricks/sdk/core.py b/databricks/sdk/core.py index 0c58d44c7..68ffe04a8 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -1,1049 +1,1053 @@ -import abc -import base64 -import configparser -import copy -import functools -import json -import logging -import os -import pathlib -import platform -import re -import subprocess -import sys -import urllib.parse -from datetime import datetime -from json import JSONDecodeError -from typing import Callable, Dict, Iterable, List, Optional, Union - -import requests -import requests.auth -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment -from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable, - Token, TokenCache, TokenSource) -from .version import __version__ - -__all__ = ['Config', 'DatabricksError'] - -logger = logging.getLogger('databricks.sdk') - -HeaderFactory = Callable[[], Dict[str, str]] - - -class CredentialsProvider(abc.ABC): - """ CredentialsProvider is the protocol (call-side interface) - for authenticating requests to Databricks REST APIs""" - - @abc.abstractmethod - def auth_type(self) -> str: - ... - - @abc.abstractmethod - def __call__(self, cfg: 'Config') -> HeaderFactory: - ... - - -def credentials_provider(name: str, require: List[str]): - """ Given the function that receives a Config and returns RequestVisitor, - create CredentialsProvider with a given name and required configuration - attribute names to be present for this function to be called. """ - - def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: - - @functools.wraps(func) - def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: - for attr in require: - if not getattr(cfg, attr): - return None - return func(cfg) - - wrapper.auth_type = lambda: name - return wrapper - - return inner - - -@credentials_provider('basic', ['host', 'username', 'password']) -def basic_auth(cfg: 'Config') -> HeaderFactory: - """ Given username and password, add base64-encoded Basic credentials """ - encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() - static_credentials = {'Authorization': f'Basic {encoded}'} - - def inner() -> Dict[str, str]: - return static_credentials - - return inner - - -@credentials_provider('pat', ['host', 'token']) -def pat_auth(cfg: 'Config') -> HeaderFactory: - """ Adds Databricks Personal Access Token to every request """ - static_credentials = {'Authorization': f'Bearer {cfg.token}'} - - def inner() -> Dict[str, str]: - return static_credentials - - return inner - - -@credentials_provider('runtime', []) -def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: - from databricks.sdk.runtime import init_runtime_native_auth - if init_runtime_native_auth is not None: - host, inner = init_runtime_native_auth() - cfg.host = host - return inner - try: - from dbruntime.databricks_repl_context import get_context - ctx = get_context() - if ctx is None: - logger.debug('Empty REPL context returned, skipping runtime auth') - return None - cfg.host = f'https://{ctx.workspaceUrl}' - - def inner() -> Dict[str, str]: - ctx = get_context() - return {'Authorization': f'Bearer {ctx.apiToken}'} - - return inner - except ImportError: - return None - - -@credentials_provider('oauth-m2m', ['is_aws', 'host', 'client_id', 'client_secret']) -def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed Databricks machine-to-machine OAuth Bearer 
token to every request, - if /oidc/.well-known/oauth-authorization-server is available on the given host. """ - # TODO: Azure returns 404 for UC workspace after redirecting to - # https://login.microsoftonline.com/{cfg.azure_tenant_id}/.well-known/oauth-authorization-server - oidc = cfg.oidc_endpoints - if oidc is None: - return None - token_source = ClientCredentials(client_id=cfg.client_id, - client_secret=cfg.client_secret, - token_url=oidc.token_endpoint, - scopes=["all-apis"], - use_header=True) - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -@credentials_provider('external-browser', ['host', 'auth_type']) -def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: - if cfg.auth_type != 'external-browser': - return None - if cfg.client_id: - client_id = cfg.client_id - elif cfg.is_aws: - client_id = 'databricks-cli' - elif cfg.is_azure: - # Use Azure AD app for cases when Azure CLI is not available on the machine. - # App has to be registered as Single-page multi-tenant to support PKCE - # TODO: temporary app ID, change it later. - client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' - else: - raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) - - # Load cached credentials from disk if they exist. - # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) - credentials = token_cache.load() - if credentials: - # Force a refresh in case the loaded credentials are expired. - credentials.token() - else: - consent = oauth_client.initiate_consent() - if not consent: - return None - credentials = consent.launch_external_browser() - token_cache.save(credentials) - return credentials(cfg) - - -def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]): - """ Resolves Azure Databricks workspace URL from ARM Resource ID """ - if cfg.host: - return - if not cfg.azure_workspace_resource_id: - return - arm = cfg.arm_environment.resource_manager_endpoint - token = token_source_for(arm).token() - resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01", - headers={"Authorization": f"Bearer {token.access_token}"}) - if not resp.ok: - raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") - cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" - - -@credentials_provider('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) -def azure_service_principal(cfg: 'Config') -> HeaderFactory: - """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens - to every request, while automatically resolving different Azure environment endpoints. 
""" - - def token_source_for(resource: str) -> TokenSource: - aad_endpoint = cfg.arm_environment.active_directory_endpoint - return ClientCredentials(client_id=cfg.azure_client_id, - client_secret=cfg.azure_client_secret, - token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params={"resource": resource}, - use_params=True) - - _ensure_host_present(cfg, token_source_for) - logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) - inner = token_source_for(cfg.effective_azure_login_app_id) - cloud = token_source_for(cfg.arm_environment.service_management_endpoint) - - def refreshed_headers() -> Dict[str, str]: - headers = { - 'Authorization': f"Bearer {inner.token().access_token}", - 'X-Databricks-Azure-SP-Management-Token': cloud.token().access_token, - } - if cfg.azure_workspace_resource_id: - headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id - return headers - - return refreshed_headers - - -class CliTokenSource(Refreshable): - - def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str): - super().__init__() - self._cmd = cmd - self._token_type_field = token_type_field - self._access_token_field = access_token_field - self._expiry_field = expiry_field - - @staticmethod - def _parse_expiry(expiry: str) -> datetime: - for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f%z"): - try: - return datetime.strptime(expiry, fmt) - except ValueError as e: - last_e = e - if last_e: - raise last_e - - def refresh(self) -> Token: - try: - is_windows = sys.platform.startswith('win') - # windows requires shell=True to be able to execute 'az login' or other commands - # cannot use shell=True all the time, as it breaks macOS - out = subprocess.check_output(self._cmd, stderr=subprocess.STDOUT, shell=is_windows) - it = json.loads(out.decode()) - expires_on = self._parse_expiry(it[self._expiry_field]) - return Token(access_token=it[self._access_token_field], - token_type=it[self._token_type_field], - expiry=expires_on) - except ValueError as e: - raise ValueError(f"cannot unmarshal CLI result: {e}") - except subprocess.CalledProcessError as e: - message = e.output.decode().strip() - raise IOError(f'cannot get access token: {message}') from e - - -class AzureCliTokenSource(CliTokenSource): - """ Obtain the token granted by `az login` CLI command """ - - def __init__(self, resource: str): - cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - super().__init__(cmd=cmd, - token_type_field='tokenType', - access_token_field='accessToken', - expiry_field='expiresOn') - - -@credentials_provider('azure-cli', ['is_azure']) -def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed OAuth token granted by `az login` command to every request. """ - token_source = AzureCliTokenSource(cfg.effective_azure_login_app_id) - try: - token_source.token() - except FileNotFoundError: - doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest' - logger.debug(f'Most likely Azure CLI is not installed. 
See {doc} for details') - return None - - _ensure_host_present(cfg, lambda resource: AzureCliTokenSource(resource)) - logger.info("Using Azure CLI authentication with AAD tokens") - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -class DatabricksCliTokenSource(CliTokenSource): - """ Obtain the token granted by `databricks auth login` CLI command """ - - def __init__(self, cfg: 'Config'): - args = ['auth', 'token', '--host', cfg.host] - if cfg.is_account_client: - args += ['--account-id', cfg.account_id] - - cli_path = cfg.databricks_cli_path - if not cli_path: - cli_path = 'databricks' - - # If the path is unqualified, look it up in PATH. - if cli_path.count("/") == 0: - cli_path = self.__class__._find_executable(cli_path) - - super().__init__(cmd=[cli_path, *args], - token_type_field='token_type', - access_token_field='access_token', - expiry_field='expiry') - - @staticmethod - def _find_executable(name) -> str: - err = FileNotFoundError("Most likely the Databricks CLI is not installed") - for dir in os.getenv("PATH", default="").split(os.path.pathsep): - path = pathlib.Path(dir).joinpath(name).resolve() - if not path.is_file(): - continue - - # The new Databricks CLI is a single binary with size > 1MB. - # We use the size as a signal to determine which Databricks CLI is installed. - stat = path.stat() - if stat.st_size < (1024 * 1024): - err = FileNotFoundError("Databricks CLI version <0.100.0 detected") - continue - - return str(path) - - raise err - - -@credentials_provider('databricks-cli', ['host', 'is_aws']) -def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: - try: - token_source = DatabricksCliTokenSource(cfg) - except FileNotFoundError as e: - logger.debug(e) - return None - - try: - token_source.token() - except IOError as e: - if 'databricks OAuth is not' in str(e): - logger.debug(f'OAuth not configured or not available: {e}') - return None - raise e - - logger.info("Using Databricks CLI authentication") - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -class MetadataServiceTokenSource(Refreshable): - """ Obtain the token granted by Databricks Metadata Service """ - METADATA_SERVICE_VERSION = "1" - METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version" - METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host" - _metadata_service_timeout = 10 # seconds - - def __init__(self, cfg: 'Config'): - super().__init__() - self.url = cfg.metadata_service_url - self.host = cfg.host - - def refresh(self) -> Token: - resp = requests.get(self.url, - timeout=self._metadata_service_timeout, - headers={ - self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION, - self.METADATA_SERVICE_HOST_HEADER: self.host - }) - json_resp: dict[str, Union[str, float]] = resp.json() - access_token = json_resp.get("access_token", None) - if access_token is None: - raise ValueError("Metadata Service returned empty token") - token_type = json_resp.get("token_type", None) - if token_type is None: - raise ValueError("Metadata Service returned empty token type") - if json_resp["expires_on"] in ["", None]: - raise ValueError("Metadata Service returned invalid expiry") - try: - expiry = datetime.fromtimestamp(json_resp["expires_on"]) - except: - raise ValueError("Metadata Service returned invalid expiry") - - return Token(access_token=access_token, token_type=token_type, 
expiry=expiry) - - -@credentials_provider('metadata-service', ['host', 'metadata_service_url']) -def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed token granted by Databricks Metadata Service to every request. """ - - token_source = MetadataServiceTokenSource(cfg) - token_source.token() - logger.info("Using Databricks Metadata Service authentication") - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -class DefaultCredentials: - """ Select the first applicable credential provider from the chain """ - - def __init__(self) -> None: - self._auth_type = 'default' - - def auth_type(self) -> str: - return self._auth_type - - def __call__(self, cfg: 'Config') -> HeaderFactory: - auth_providers = [ - pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, - azure_cli, external_browser, databricks_cli, runtime_native_auth - ] - for provider in auth_providers: - auth_type = provider.auth_type() - if cfg.auth_type and auth_type != cfg.auth_type: - # ignore other auth types if one is explicitly enforced - logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred") - continue - logger.debug(f'Attempting to configure auth: {auth_type}') - try: - header_factory = provider(cfg) - if not header_factory: - continue - self._auth_type = auth_type - return header_factory - except Exception as e: - raise ValueError(f'{auth_type}: {e}') from e - raise ValueError('cannot configure default credentials') - - -class ConfigAttribute: - """ Configuration attribute metadata and descriptor protocols. """ - - # name and transform are discovered from Config.__new__ - name: str = None - transform: type = str - - def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): - self.env = env - self.auth = auth - self.sensitive = sensitive - - def __get__(self, cfg: 'Config', owner): - if not cfg: - return None - return cfg._inner.get(self.name, None) - - def __set__(self, cfg: 'Config', value: any): - cfg._inner[self.name] = self.transform(value) - - def __repr__(self) -> str: - return f"" - - -class Config: - host = ConfigAttribute(env='DATABRICKS_HOST') - account_id = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID') - token = ConfigAttribute(env='DATABRICKS_TOKEN', auth='pat', sensitive=True) - username = ConfigAttribute(env='DATABRICKS_USERNAME', auth='basic') - password = ConfigAttribute(env='DATABRICKS_PASSWORD', auth='basic', sensitive=True) - client_id = ConfigAttribute(env='DATABRICKS_CLIENT_ID', auth='oauth') - client_secret = ConfigAttribute(env='DATABRICKS_CLIENT_SECRET', auth='oauth', sensitive=True) - profile = ConfigAttribute(env='DATABRICKS_CONFIG_PROFILE') - config_file = ConfigAttribute(env='DATABRICKS_CONFIG_FILE') - google_service_account = ConfigAttribute(env='DATABRICKS_GOOGLE_SERVICE_ACCOUNT', auth='google') - google_credentials = ConfigAttribute(env='GOOGLE_CREDENTIALS', auth='google', sensitive=True) - azure_workspace_resource_id = ConfigAttribute(env='DATABRICKS_AZURE_RESOURCE_ID', auth='azure') - azure_use_msi: bool = ConfigAttribute(env='ARM_USE_MSI', auth='azure') - azure_client_secret = ConfigAttribute(env='ARM_CLIENT_SECRET', auth='azure', sensitive=True) - azure_client_id = ConfigAttribute(env='ARM_CLIENT_ID', auth='azure') - azure_tenant_id = ConfigAttribute(env='ARM_TENANT_ID', auth='azure') - azure_environment = ConfigAttribute(env='ARM_ENVIRONMENT') - azure_login_app_id = 
ConfigAttribute(env='DATABRICKS_AZURE_LOGIN_APP_ID', auth='azure') - databricks_cli_path = ConfigAttribute(env='DATABRICKS_CLI_PATH') - auth_type = ConfigAttribute(env='DATABRICKS_AUTH_TYPE') - cluster_id = ConfigAttribute(env='DATABRICKS_CLUSTER_ID') - warehouse_id = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID') - skip_verify: bool = ConfigAttribute() - http_timeout_seconds: int = ConfigAttribute() - debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES') - debug_headers: bool = ConfigAttribute(env='DATABRICKS_DEBUG_HEADERS') - rate_limit: int = ConfigAttribute(env='DATABRICKS_RATE_LIMIT') - retry_timeout_seconds: int = ConfigAttribute() - metadata_service_url = ConfigAttribute(env='DATABRICKS_METADATA_SERVICE_URL', - auth='metadata-service', - sensitive=True) - - def __init__(self, - *, - credentials_provider: CredentialsProvider = None, - product="unknown", - product_version="0.0.0", - **kwargs): - self._inner = {} - self._user_agent_other_info = [] - self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() - try: - self._set_inner_config(kwargs) - self._load_from_env() - self._known_file_config_loader() - self._fix_host_if_needed() - self._validate() - self._init_auth() - self._product = product - self._product_version = product_version - except ValueError as e: - message = self.wrap_debug_info(str(e)) - raise ValueError(message) from e - - def wrap_debug_info(self, message: str) -> str: - debug_string = self.debug_string() - if debug_string: - message = f'{message.rstrip(".")}. {debug_string}' - return message - - @staticmethod - def parse_dsn(dsn: str) -> 'Config': - uri = urllib.parse.urlparse(dsn) - if uri.scheme != 'databricks': - raise ValueError(f'Expected databricks:// scheme, got {uri.scheme}://') - kwargs = {'host': f'https://{uri.hostname}'} - if uri.username: - kwargs['username'] = uri.username - if uri.password: - kwargs['password'] = uri.password - query = dict(urllib.parse.parse_qsl(uri.query)) - for attr in Config.attributes(): - if attr.name not in query: - continue - kwargs[attr.name] = query[attr.name] - return Config(**kwargs) - - def authenticate(self) -> Dict[str, str]: - """ Returns a list of fresh authentication headers """ - return self._header_factory() - - def as_dict(self) -> dict: - return self._inner - - @property - def is_azure(self) -> bool: - has_resource_id = self.azure_workspace_resource_id is not None - has_host = self.host is not None - is_public_cloud = has_host and ".azuredatabricks.net" in self.host - is_china_cloud = has_host and ".databricks.azure.cn" in self.host - is_gov_cloud = has_host and ".databricks.azure.us" in self.host - is_valid_cloud = is_public_cloud or is_china_cloud or is_gov_cloud - return has_resource_id or (has_host and is_valid_cloud) - - @property - def is_gcp(self) -> bool: - return self.host and ".gcp.databricks.com" in self.host - - @property - def is_aws(self) -> bool: - return not self.is_azure and not self.is_gcp - - @property - def is_account_client(self) -> bool: - if not self.host: - return False - return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") - - @property - def arm_environment(self) -> AzureEnvironment: - env = self.azure_environment if self.azure_environment else "PUBLIC" - try: - return ENVIRONMENTS[env] - except KeyError: - raise ValueError(f"Cannot find Azure {env} Environment") - - @property - def effective_azure_login_app_id(self): - app_id = self.azure_login_app_id - if app_id: - return 
app_id - return ARM_DATABRICKS_RESOURCE_ID - - @property - def hostname(self) -> str: - url = urllib.parse.urlparse(self.host) - return url.netloc - - @property - def is_any_auth_configured(self) -> bool: - for attr in Config.attributes(): - if not attr.auth: - continue - value = self._inner.get(attr.name, None) - if value: - return True - return False - - @property - def user_agent(self): - """ Returns User-Agent header used by this SDK """ - py_version = platform.python_version() - os_name = platform.uname().system.lower() - - ua = [ - f"{self._product}/{self._product_version}", f"databricks-sdk-py/{__version__}", - f"python/{py_version}", f"os/{os_name}", f"auth/{self.auth_type}", - ] - if len(self._user_agent_other_info) > 0: - ua.append(' '.join(self._user_agent_other_info)) - if len(self._upstream_user_agent) > 0: - ua.append(self._upstream_user_agent) - - return ' '.join(ua) - - @property - def _upstream_user_agent(self) -> str: - product = os.environ.get('DATABRICKS_SDK_UPSTREAM', None) - product_version = os.environ.get('DATABRICKS_SDK_UPSTREAM_VERSION', None) - if product is not None and product_version is not None: - return f"upstream/{product} upstream-version/{product_version}" - return "" - - def with_user_agent_extra(self, key: str, value: str) -> 'Config': - self._user_agent_other_info.append(f"{key}/{value}") - return self - - @property - def oidc_endpoints(self) -> Optional[OidcEndpoints]: - self._fix_host_if_needed() - if not self.host: - return None - if self.is_azure: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) - if self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) - - def debug_string(self) -> str: - """ Returns log-friendly representation of configured attributes """ - buf = [] - attrs_used = [] - envs_used = [] - for attr in Config.attributes(): - if attr.env and os.environ.get(attr.env): - envs_used.append(attr.env) - value = getattr(self, attr.name) - if not value: - continue - safe = '***' if attr.sensitive else f'{value}' - attrs_used.append(f'{attr.name}={safe}') - if attrs_used: - buf.append(f'Config: {", ".join(attrs_used)}') - if envs_used: - buf.append(f'Env: {", ".join(envs_used)}') - return '. '.join(buf) - - def to_dict(self) -> Dict[str, any]: - return self._inner - - @property - def sql_http_path(self) -> Optional[str]: - """(Experimental) Return HTTP path for SQL Drivers. - - If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument - used in construction of JDBC/ODBC DSN string. 
- - See https://docs.databricks.com/integrations/jdbc-odbc-bi.html - """ - if (not self.cluster_id) and (not self.warehouse_id): - return None - if self.cluster_id and self.warehouse_id: - raise ValueError('cannot have both cluster_id and warehouse_id') - headers = self.authenticate() - headers['User-Agent'] = f'{self.user_agent} sdk-feature/sql-http-path' - if self.cluster_id: - response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers) - # get workspace ID from the response header - workspace_id = response.headers.get('x-databricks-org-id') - return f'sql/protocolv1/o/{workspace_id}/{self.cluster_id}' - if self.warehouse_id: - return f'/sql/1.0/warehouses/{self.warehouse_id}' - - @classmethod - def attributes(cls) -> Iterable[ConfigAttribute]: - """ Returns a list of Databricks SDK configuration metadata """ - if hasattr(cls, '_attributes'): - return cls._attributes - # Python 3.7 compatibility: getting type hints require extra hop, as described in - # "Accessing The Annotations Dict Of An Object In Python 3.9 And Older" section of - # https://docs.python.org/3/howto/annotations.html - anno = cls.__dict__['__annotations__'] - attrs = [] - for name, v in cls.__dict__.items(): - if type(v) != ConfigAttribute: - continue - v.name = name - v.transform = anno.get(name, str) - attrs.append(v) - cls._attributes = attrs - return cls._attributes - - def _fix_host_if_needed(self): - if not self.host: - return - # fix url to remove trailing slash - o = urllib.parse.urlparse(self.host) - if not o.hostname: - # only hostname is specified - self.host = f"https://{self.host}" - else: - self.host = f"{o.scheme}://{o.netloc}" - - def _set_inner_config(self, keyword_args: Dict[str, any]): - for attr in self.attributes(): - if attr.name not in keyword_args: - continue - if keyword_args.get(attr.name, None) is None: - continue - # make sure that args are of correct type - self._inner[attr.name] = attr.transform(keyword_args[attr.name]) - - def _load_from_env(self): - found = False - for attr in Config.attributes(): - if not attr.env: - continue - if attr.name in self._inner: - continue - value = os.environ.get(attr.env) - if not value: - continue - self._inner[attr.name] = value - found = True - if found: - logger.debug('Loaded from environment') - - def _known_file_config_loader(self): - if not self.profile and (self.is_any_auth_configured or self.host - or self.azure_workspace_resource_id): - # skip loading configuration file if there's any auth configured - # directly as part of the Config() constructor. - return - config_file = self.config_file - if not config_file: - config_file = "~/.databrickscfg" - config_path = pathlib.Path(config_file).expanduser() - if not config_path.exists(): - logger.debug("%s does not exist", config_path) - return - ini_file = configparser.ConfigParser() - ini_file.read(config_path) - profile = self.profile - has_explicit_profile = self.profile is not None - # In Go SDK, we skip merging the profile with DEFAULT section, though Python's ConfigParser.items() - # is returning profile key-value pairs _including those from DEFAULT_. This is not what we expect - # from Unified Auth test suite at the moment. Hence, the private variable access. 
- # See: https://docs.python.org/3/library/configparser.html#mapping-protocol-access - if not has_explicit_profile and not ini_file.defaults(): - logger.debug(f'{config_path} has no DEFAULT profile configured') - return - if not has_explicit_profile: - profile = "DEFAULT" - profiles = ini_file._sections - if ini_file.defaults(): - profiles['DEFAULT'] = ini_file.defaults() - if profile not in profiles: - raise ValueError(f'resolve: {config_path} has no {profile} profile configured') - raw_config = profiles[profile] - logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') - for k, v in raw_config.items(): - if k in self._inner: - # don't overwrite a value previously set - continue - self.__setattr__(k, v) - - def _validate(self): - auths_used = set() - for attr in Config.attributes(): - if attr.name not in self._inner: - continue - if not attr.auth: - continue - auths_used.add(attr.auth) - if len(auths_used) <= 1: - return - if self.auth_type: - # client has auth preference set - return - names = " and ".join(sorted(auths_used)) - raise ValueError(f'validate: more than one authorization method configured: {names}') - - def _init_auth(self): - try: - self._header_factory = self._credentials_provider(self) - self.auth_type = self._credentials_provider.auth_type() - if not self._header_factory: - raise ValueError('not configured') - except ValueError as e: - raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e - - def __repr__(self): - return f'<{self.debug_string()}>' - - def copy(self): - """Creates a copy of the config object. - All the copies share most of their internal state (ie, shared reference to fields such as credential_provider). - Copies have their own instances of the following fields - - `_user_agent_other_info` - """ - cpy: Config = copy.copy(self) - cpy._user_agent_other_info = copy.deepcopy(self._user_agent_other_info) - return cpy - - -class DatabricksError(IOError): - """ Generic error from Databricks REST API """ - - def __init__(self, - message: str = None, - *, - error_code: str = None, - detail: str = None, - status: str = None, - scimType: str = None, - error: str = None, - **kwargs): - if error: - # API 1.2 has different response format, let's adapt - message = error - if detail: - # Handle SCIM error message details - # @see https://tools.ietf.org/html/rfc7644#section-3.7.3 - if detail == "null": - message = "SCIM API Internal Error" - else: - message = detail - # add more context from SCIM responses - message = f"{scimType} {message}".strip(" ") - error_code = f"SCIM_{status}" - super().__init__(message if message else error) - self.error_code = error_code - self.kwargs = kwargs - - -class ApiClient: - _cfg: Config - - def __init__(self, cfg: Config = None): - - if cfg is None: - cfg = Config() - - self._cfg = cfg - self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96 - self._user_agent_base = cfg.user_agent - - retry_strategy = Retry( - total=6, - backoff_factor=1, - status_forcelist=[429], - allowed_methods={"POST"} | set(Retry.DEFAULT_ALLOWED_METHODS), - respect_retry_after_header=True, - raise_on_status=False, # return original response when retries have been exhausted - ) - - self._session = requests.Session() - self._session.auth = self._authenticate - self._session.mount("https://", HTTPAdapter(max_retries=retry_strategy)) - - @property - def account_id(self) -> str: - return self._cfg.account_id - - @property - def is_account_client(self) -> bool: - return 
self._cfg.is_account_client - - def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: - headers = self._cfg.authenticate() - for k, v in headers.items(): - r.headers[k] = v - return r - - @staticmethod - def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: - # Convert True -> "true" for Databricks APIs to understand booleans. - # See: https://github.com/databricks/databricks-sdk-py/issues/142 - if query is None: - return None - return {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} - - def do(self, - method: str, - path: str, - query: dict = None, - body: dict = None, - raw: bool = False, - files=None, - data=None) -> dict: - headers = {'Accept': 'application/json', 'User-Agent': self._user_agent_base} - response = self._session.request(method, - f"{self._cfg.host}{path}", - params=self._fix_query_string(query), - json=body, - headers=headers, - files=files, - data=data, - stream=True if raw else False) - try: - self._record_request_log(response, raw=raw or data is not None or files is not None) - if not response.ok: - # TODO: experiment with traceback pruning for better readability - # See https://stackoverflow.com/a/58821552/277035 - payload = response.json() - raise self._make_nicer_error(status_code=response.status_code, **payload) from None - if raw: - return response.raw - if not len(response.content): - return {} - return response.json() - except requests.exceptions.JSONDecodeError: - message = self._make_sense_from_html(response.text) - if not message: - message = response.reason - raise self._make_nicer_error(message=message) from None - - @staticmethod - def _make_sense_from_html(txt: str) -> str: - matchers = [r'
<pre>(.*)</pre>', r'<title>(.*)</title>
', r'(.*)'] - for attempt in matchers: - expr = re.compile(attempt, re.MULTILINE) - match = expr.search(txt) - if not match: - continue - return match.group(1).strip() - return txt - - def _make_nicer_error(self, status_code: int = 200, **kwargs) -> DatabricksError: - message = kwargs.get('message', 'request failed') - is_http_unauthorized_or_forbidden = status_code in (401, 403) - if is_http_unauthorized_or_forbidden: - message = self._cfg.wrap_debug_info(message) - kwargs['message'] = message - return DatabricksError(**kwargs) - - def _record_request_log(self, response: requests.Response, raw=False): - if not logger.isEnabledFor(logging.DEBUG): - return - request = response.request - url = urllib.parse.urlparse(request.url) - query = '' - if url.query: - query = f'?{urllib.parse.unquote(url.query)}' - sb = [f'{request.method} {urllib.parse.unquote(url.path)}{query}'] - if self._cfg.debug_headers: - if self._cfg.host: - sb.append(f'> * Host: {self._cfg.host}') - for k, v in request.headers.items(): - sb.append(f'> * {k}: {self._only_n_bytes(v, self._debug_truncate_bytes)}') - if request.body: - sb.append("> [raw stream]" if raw else self._redacted_dump("> ", request.body)) - sb.append(f'< {response.status_code} {response.reason}') - if raw and response.headers.get('Content-Type', None) != 'application/json': - # Raw streams with `Transfer-Encoding: chunked` do not have `Content-Type` header - sb.append("< [raw stream]") - elif response.content: - sb.append(self._redacted_dump("< ", response.content)) - logger.debug("\n".join(sb)) - - @staticmethod - def _mask(m: Dict[str, any]): - for k in m: - if k in {'bytes_value', 'string_value', 'token_value', 'value', 'content'}: - m[k] = "**REDACTED**" - - @staticmethod - def _map_keys(m: Dict[str, any]) -> List[str]: - keys = list(m.keys()) - keys.sort() - return keys - - @staticmethod - def _only_n_bytes(j: str, num_bytes: int = 96) -> str: - diff = len(j.encode('utf-8')) - num_bytes - if diff > 0: - return f"{j[:num_bytes]}... ({diff} more bytes)" - return j - - def _recursive_marshal_dict(self, m, budget) -> dict: - out = {} - self._mask(m) - for k in sorted(m.keys()): - raw = self._recursive_marshal(m[k], budget) - out[k] = raw - budget -= len(str(raw)) - return out - - def _recursive_marshal_list(self, s, budget) -> list: - out = [] - for i in range(len(s)): - if i > 0 >= budget: - out.append("... (%d additional elements)" % (len(s) - len(out))) - break - raw = self._recursive_marshal(s[i], budget) - out.append(raw) - budget -= len(str(raw)) - return out - - def _recursive_marshal(self, v: any, budget: int) -> any: - if isinstance(v, dict): - return self._recursive_marshal_dict(v, budget) - elif isinstance(v, list): - return self._recursive_marshal_list(v, budget) - elif isinstance(v, str): - return self._only_n_bytes(v, self._debug_truncate_bytes) - else: - return v - - def _redacted_dump(self, prefix: str, body: str) -> str: - if len(body) == 0: - return "" - try: - # Unmarshal body into primitive types. - tmp = json.loads(body) - max_bytes = 96 - if self._debug_truncate_bytes > max_bytes: - max_bytes = self._debug_truncate_bytes - # Re-marshal body taking redaction and character limit into account. 
- raw = self._recursive_marshal(tmp, max_bytes) - return "\n".join([f'{prefix}{line}' for line in json.dumps(raw, indent=2).split("\n")]) - except JSONDecodeError: - return f'{prefix}[non-JSON document of {len(body)} bytes]' +import abc +import base64 +import configparser +import copy +import functools +import json +import logging +import os +import pathlib +import platform +import re +import subprocess +import sys +import urllib.parse +from datetime import datetime +from json import JSONDecodeError +from typing import Callable, Dict, Iterable, List, Optional, Union + +import requests +import requests.auth +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment +from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable, + Token, TokenCache, TokenSource) +from .version import __version__ + +__all__ = ['Config', 'DatabricksError'] + +logger = logging.getLogger('databricks.sdk') + +HeaderFactory = Callable[[], Dict[str, str]] + + +class CredentialsProvider(abc.ABC): + """ CredentialsProvider is the protocol (call-side interface) + for authenticating requests to Databricks REST APIs""" + + @abc.abstractmethod + def auth_type(self) -> str: + ... + + @abc.abstractmethod + def __call__(self, cfg: 'Config') -> HeaderFactory: + ... + + +def credentials_provider(name: str, require: List[str]): + """ Given the function that receives a Config and returns RequestVisitor, + create CredentialsProvider with a given name and required configuration + attribute names to be present for this function to be called. """ + + def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: + + @functools.wraps(func) + def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: + for attr in require: + if not getattr(cfg, attr): + return None + return func(cfg) + + wrapper.auth_type = lambda: name + return wrapper + + return inner + + +@credentials_provider('basic', ['host', 'username', 'password']) +def basic_auth(cfg: 'Config') -> HeaderFactory: + """ Given username and password, add base64-encoded Basic credentials """ + encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() + static_credentials = {'Authorization': f'Basic {encoded}'} + + def inner() -> Dict[str, str]: + return static_credentials + + return inner + + +@credentials_provider('pat', ['host', 'token']) +def pat_auth(cfg: 'Config') -> HeaderFactory: + """ Adds Databricks Personal Access Token to every request """ + static_credentials = {'Authorization': f'Bearer {cfg.token}'} + + def inner() -> Dict[str, str]: + return static_credentials + + return inner + + +@credentials_provider('runtime', []) +def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: + from databricks.sdk.runtime import init_runtime_native_auth + if init_runtime_native_auth is not None: + host, inner = init_runtime_native_auth() + cfg.host = host + return inner + try: + from dbruntime.databricks_repl_context import get_context + ctx = get_context() + if ctx is None: + logger.debug('Empty REPL context returned, skipping runtime auth') + return None + cfg.host = f'https://{ctx.workspaceUrl}' + + def inner() -> Dict[str, str]: + ctx = get_context() + return {'Authorization': f'Bearer {ctx.apiToken}'} + + return inner + except ImportError: + return None + + +@credentials_provider('oauth-m2m', ['is_aws', 'host', 'client_id', 'client_secret']) +def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: + """ 
Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, + if /oidc/.well-known/oauth-authorization-server is available on the given host. """ + # TODO: Azure returns 404 for UC workspace after redirecting to + # https://login.microsoftonline.com/{cfg.azure_tenant_id}/.well-known/oauth-authorization-server + oidc = cfg.oidc_endpoints + if oidc is None: + return None + token_source = ClientCredentials(client_id=cfg.client_id, + client_secret=cfg.client_secret, + token_url=oidc.token_endpoint, + scopes=["all-apis"], + use_header=True) + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +@credentials_provider('external-browser', ['host', 'auth_type']) +def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: + if cfg.auth_type != 'external-browser': + return None + if cfg.client_id: + client_id = cfg.client_id + elif cfg.is_aws: + client_id = 'databricks-cli' + elif cfg.is_azure: + # Use Azure AD app for cases when Azure CLI is not available on the machine. + # App has to be registered as Single-page multi-tenant to support PKCE + # TODO: temporary app ID, change it later. + client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' + else: + raise ValueError(f'local browser SSO is not supported') + oauth_client = OAuthClient(host=cfg.host, + client_id=client_id, + redirect_url='http://localhost:8020', + client_secret=cfg.client_secret) + + # Load cached credentials from disk if they exist. + # Note that these are local to the Python SDK and not reused by other SDKs. + token_cache = TokenCache(oauth_client) + credentials = token_cache.load() + if credentials: + # Force a refresh in case the loaded credentials are expired. + credentials.token() + else: + consent = oauth_client.initiate_consent() + if not consent: + return None + credentials = consent.launch_external_browser() + token_cache.save(credentials) + return credentials(cfg) + + +def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]): + """ Resolves Azure Databricks workspace URL from ARM Resource ID """ + if cfg.host: + return + if not cfg.azure_workspace_resource_id: + return + arm = cfg.arm_environment.resource_manager_endpoint + token = token_source_for(arm).token() + resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01", + headers={"Authorization": f"Bearer {token.access_token}"}) + if not resp.ok: + raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") + cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" + + +@credentials_provider('azure-client-secret', + ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +def azure_service_principal(cfg: 'Config') -> HeaderFactory: + """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens + to every request, while automatically resolving different Azure environment endpoints. 
""" + + def token_source_for(resource: str) -> TokenSource: + aad_endpoint = cfg.arm_environment.active_directory_endpoint + return ClientCredentials(client_id=cfg.azure_client_id, + client_secret=cfg.azure_client_secret, + token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", + endpoint_params={"resource": resource}, + use_params=True) + + _ensure_host_present(cfg, token_source_for) + logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) + inner = token_source_for(cfg.effective_azure_login_app_id) + cloud = token_source_for(cfg.arm_environment.service_management_endpoint) + + def refreshed_headers() -> Dict[str, str]: + headers = { + 'Authorization': f"Bearer {inner.token().access_token}", + 'X-Databricks-Azure-SP-Management-Token': cloud.token().access_token, + } + if cfg.azure_workspace_resource_id: + headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id + return headers + + return refreshed_headers + + +class CliTokenSource(Refreshable): + + def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str): + super().__init__() + self._cmd = cmd + self._token_type_field = token_type_field + self._access_token_field = access_token_field + self._expiry_field = expiry_field + + @staticmethod + def _parse_expiry(expiry: str) -> datetime: + for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f%z"): + try: + return datetime.strptime(expiry, fmt) + except ValueError as e: + last_e = e + if last_e: + raise last_e + + def refresh(self) -> Token: + try: + is_windows = sys.platform.startswith('win') + # windows requires shell=True to be able to execute 'az login' or other commands + # cannot use shell=True all the time, as it breaks macOS + out = subprocess.check_output(self._cmd, stderr=subprocess.STDOUT, shell=is_windows) + it = json.loads(out.decode()) + expires_on = self._parse_expiry(it[self._expiry_field]) + return Token(access_token=it[self._access_token_field], + token_type=it[self._token_type_field], + expiry=expires_on) + except ValueError as e: + raise ValueError(f"cannot unmarshal CLI result: {e}") + except subprocess.CalledProcessError as e: + message = e.output.decode().strip() + raise IOError(f'cannot get access token: {message}') from e + + +class AzureCliTokenSource(CliTokenSource): + """ Obtain the token granted by `az login` CLI command """ + + def __init__(self, resource: str): + cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] + super().__init__(cmd=cmd, + token_type_field='tokenType', + access_token_field='accessToken', + expiry_field='expiresOn') + + +@credentials_provider('azure-cli', ['is_azure']) +def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: + """ Adds refreshed OAuth token granted by `az login` command to every request. """ + token_source = AzureCliTokenSource(cfg.effective_azure_login_app_id) + try: + token_source.token() + except FileNotFoundError: + doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest' + logger.debug(f'Most likely Azure CLI is not installed. 
See {doc} for details') + return None + + _ensure_host_present(cfg, lambda resource: AzureCliTokenSource(resource)) + logger.info("Using Azure CLI authentication with AAD tokens") + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +class DatabricksCliTokenSource(CliTokenSource): + """ Obtain the token granted by `databricks auth login` CLI command """ + + def __init__(self, cfg: 'Config'): + args = ['auth', 'token', '--host', cfg.host] + if cfg.is_account_client: + args += ['--account-id', cfg.account_id] + + cli_path = cfg.databricks_cli_path + if not cli_path: + cli_path = 'databricks' + + # If the path is unqualified, look it up in PATH. + if cli_path.count("/") == 0: + cli_path = self.__class__._find_executable(cli_path) + + super().__init__(cmd=[cli_path, *args], + token_type_field='token_type', + access_token_field='access_token', + expiry_field='expiry') + + @staticmethod + def _find_executable(name) -> str: + err = FileNotFoundError("Most likely the Databricks CLI is not installed") + for dir in os.getenv("PATH", default="").split(os.path.pathsep): + path = pathlib.Path(dir).joinpath(name).resolve() + if not path.is_file(): + continue + + # The new Databricks CLI is a single binary with size > 1MB. + # We use the size as a signal to determine which Databricks CLI is installed. + stat = path.stat() + if stat.st_size < (1024 * 1024): + err = FileNotFoundError("Databricks CLI version <0.100.0 detected") + continue + + return str(path) + + raise err + + +@credentials_provider('databricks-cli', ['host', 'is_aws']) +def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: + try: + token_source = DatabricksCliTokenSource(cfg) + except FileNotFoundError as e: + logger.debug(e) + return None + + try: + token_source.token() + except IOError as e: + if 'databricks OAuth is not' in str(e): + logger.debug(f'OAuth not configured or not available: {e}') + return None + raise e + + logger.info("Using Databricks CLI authentication") + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +class MetadataServiceTokenSource(Refreshable): + """ Obtain the token granted by Databricks Metadata Service """ + METADATA_SERVICE_VERSION = "1" + METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version" + METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host" + _metadata_service_timeout = 10 # seconds + + def __init__(self, cfg: 'Config'): + super().__init__() + self.url = cfg.metadata_service_url + self.host = cfg.host + + def refresh(self) -> Token: + resp = requests.get(self.url, + timeout=self._metadata_service_timeout, + headers={ + self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION, + self.METADATA_SERVICE_HOST_HEADER: self.host + }) + json_resp: dict[str, Union[str, float]] = resp.json() + access_token = json_resp.get("access_token", None) + if access_token is None: + raise ValueError("Metadata Service returned empty token") + token_type = json_resp.get("token_type", None) + if token_type is None: + raise ValueError("Metadata Service returned empty token type") + if json_resp["expires_on"] in ["", None]: + raise ValueError("Metadata Service returned invalid expiry") + try: + expiry = datetime.fromtimestamp(json_resp["expires_on"]) + except: + raise ValueError("Metadata Service returned invalid expiry") + + return Token(access_token=access_token, token_type=token_type, 
expiry=expiry) + + +@credentials_provider('metadata-service', ['host', 'metadata_service_url']) +def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]: + """ Adds refreshed token granted by Databricks Metadata Service to every request. """ + + token_source = MetadataServiceTokenSource(cfg) + token_source.token() + logger.info("Using Databricks Metadata Service authentication") + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +class DefaultCredentials: + """ Select the first applicable credential provider from the chain """ + + def __init__(self) -> None: + self._auth_type = 'default' + + def auth_type(self) -> str: + return self._auth_type + + def __call__(self, cfg: 'Config') -> HeaderFactory: + auth_providers = [ + pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, + azure_cli, external_browser, databricks_cli, runtime_native_auth + ] + for provider in auth_providers: + auth_type = provider.auth_type() + if cfg.auth_type and auth_type != cfg.auth_type: + # ignore other auth types if one is explicitly enforced + logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred") + continue + logger.debug(f'Attempting to configure auth: {auth_type}') + try: + header_factory = provider(cfg) + if not header_factory: + continue + self._auth_type = auth_type + return header_factory + except Exception as e: + raise ValueError(f'{auth_type}: {e}') from e + raise ValueError('cannot configure default credentials') + + +class ConfigAttribute: + """ Configuration attribute metadata and descriptor protocols. """ + + # name and transform are discovered from Config.__new__ + name: str = None + transform: type = str + + def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): + self.env = env + self.auth = auth + self.sensitive = sensitive + + def __get__(self, cfg: 'Config', owner): + if not cfg: + return None + return cfg._inner.get(self.name, None) + + def __set__(self, cfg: 'Config', value: any): + cfg._inner[self.name] = self.transform(value) + + def __repr__(self) -> str: + return f"" + + +class Config: + host = ConfigAttribute(env='DATABRICKS_HOST') + account_id = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID') + token = ConfigAttribute(env='DATABRICKS_TOKEN', auth='pat', sensitive=True) + username = ConfigAttribute(env='DATABRICKS_USERNAME', auth='basic') + password = ConfigAttribute(env='DATABRICKS_PASSWORD', auth='basic', sensitive=True) + client_id = ConfigAttribute(env='DATABRICKS_CLIENT_ID', auth='oauth') + client_secret = ConfigAttribute(env='DATABRICKS_CLIENT_SECRET', auth='oauth', sensitive=True) + profile = ConfigAttribute(env='DATABRICKS_CONFIG_PROFILE') + config_file = ConfigAttribute(env='DATABRICKS_CONFIG_FILE') + google_service_account = ConfigAttribute(env='DATABRICKS_GOOGLE_SERVICE_ACCOUNT', auth='google') + google_credentials = ConfigAttribute(env='GOOGLE_CREDENTIALS', auth='google', sensitive=True) + azure_workspace_resource_id = ConfigAttribute(env='DATABRICKS_AZURE_RESOURCE_ID', auth='azure') + azure_use_msi: bool = ConfigAttribute(env='ARM_USE_MSI', auth='azure') + azure_client_secret = ConfigAttribute(env='ARM_CLIENT_SECRET', auth='azure', sensitive=True) + azure_client_id = ConfigAttribute(env='ARM_CLIENT_ID', auth='azure') + azure_tenant_id = ConfigAttribute(env='ARM_TENANT_ID', auth='azure') + azure_environment = ConfigAttribute(env='ARM_ENVIRONMENT') + azure_login_app_id = 
ConfigAttribute(env='DATABRICKS_AZURE_LOGIN_APP_ID', auth='azure') + databricks_cli_path = ConfigAttribute(env='DATABRICKS_CLI_PATH') + auth_type = ConfigAttribute(env='DATABRICKS_AUTH_TYPE') + cluster_id = ConfigAttribute(env='DATABRICKS_CLUSTER_ID') + warehouse_id = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID') + skip_verify: bool = ConfigAttribute() + http_timeout_seconds: int = ConfigAttribute() + debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES') + debug_headers: bool = ConfigAttribute(env='DATABRICKS_DEBUG_HEADERS') + rate_limit: int = ConfigAttribute(env='DATABRICKS_RATE_LIMIT') + retry_timeout_seconds: int = ConfigAttribute() + metadata_service_url = ConfigAttribute(env='DATABRICKS_METADATA_SERVICE_URL', + auth='metadata-service', + sensitive=True) + + def __init__(self, + *, + credentials_provider: CredentialsProvider = None, + product="unknown", + product_version="0.0.0", + **kwargs): + self._inner = {} + self._user_agent_other_info = [] + self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() + try: + self._set_inner_config(kwargs) + self._load_from_env() + self._known_file_config_loader() + self._fix_host_if_needed() + self._validate() + self._init_auth() + self._product = product + self._product_version = product_version + except ValueError as e: + message = self.wrap_debug_info(str(e)) + raise ValueError(message) from e + + def wrap_debug_info(self, message: str) -> str: + debug_string = self.debug_string() + if debug_string: + message = f'{message.rstrip(".")}. {debug_string}' + return message + + @staticmethod + def parse_dsn(dsn: str) -> 'Config': + uri = urllib.parse.urlparse(dsn) + if uri.scheme != 'databricks': + raise ValueError(f'Expected databricks:// scheme, got {uri.scheme}://') + kwargs = {'host': f'https://{uri.hostname}'} + if uri.username: + kwargs['username'] = uri.username + if uri.password: + kwargs['password'] = uri.password + query = dict(urllib.parse.parse_qsl(uri.query)) + for attr in Config.attributes(): + if attr.name not in query: + continue + kwargs[attr.name] = query[attr.name] + return Config(**kwargs) + + def authenticate(self) -> Dict[str, str]: + """ Returns a list of fresh authentication headers """ + return self._header_factory() + + def as_dict(self) -> dict: + return self._inner + + @property + def is_azure(self) -> bool: + has_resource_id = self.azure_workspace_resource_id is not None + has_host = self.host is not None + is_public_cloud = has_host and ".azuredatabricks.net" in self.host + is_china_cloud = has_host and ".databricks.azure.cn" in self.host + is_gov_cloud = has_host and ".databricks.azure.us" in self.host + is_valid_cloud = is_public_cloud or is_china_cloud or is_gov_cloud + return has_resource_id or (has_host and is_valid_cloud) + + @property + def is_gcp(self) -> bool: + return self.host and ".gcp.databricks.com" in self.host + + @property + def is_aws(self) -> bool: + return not self.is_azure and not self.is_gcp + + @property + def is_account_client(self) -> bool: + if not self.host: + return False + return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") + + @property + def arm_environment(self) -> AzureEnvironment: + env = self.azure_environment if self.azure_environment else "PUBLIC" + try: + return ENVIRONMENTS[env] + except KeyError: + raise ValueError(f"Cannot find Azure {env} Environment") + + @property + def effective_azure_login_app_id(self): + app_id = self.azure_login_app_id + if app_id: + return 
app_id + return ARM_DATABRICKS_RESOURCE_ID + + @property + def hostname(self) -> str: + url = urllib.parse.urlparse(self.host) + return url.netloc + + @property + def is_any_auth_configured(self) -> bool: + for attr in Config.attributes(): + if not attr.auth: + continue + value = self._inner.get(attr.name, None) + if value: + return True + return False + + @property + def user_agent(self): + """ Returns User-Agent header used by this SDK """ + py_version = platform.python_version() + os_name = platform.uname().system.lower() + + ua = [ + f"{self._product}/{self._product_version}", f"databricks-sdk-py/{__version__}", + f"python/{py_version}", f"os/{os_name}", f"auth/{self.auth_type}", + ] + if len(self._user_agent_other_info) > 0: + ua.append(' '.join(self._user_agent_other_info)) + if len(self._upstream_user_agent) > 0: + ua.append(self._upstream_user_agent) + + return ' '.join(ua) + + @property + def _upstream_user_agent(self) -> str: + product = os.environ.get('DATABRICKS_SDK_UPSTREAM', None) + product_version = os.environ.get('DATABRICKS_SDK_UPSTREAM_VERSION', None) + if product is not None and product_version is not None: + return f"upstream/{product} upstream-version/{product_version}" + return "" + + def with_user_agent_extra(self, key: str, value: str) -> 'Config': + self._user_agent_other_info.append(f"{key}/{value}") + return self + + @property + def oidc_endpoints(self) -> Optional[OidcEndpoints]: + self._fix_host_if_needed() + if not self.host: + return None + if self.is_azure: + # Retrieve authorize endpoint to retrieve token endpoint after + res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + if self.account_id: + prefix = f'{self.host}/oidc/accounts/{self.account_id}' + return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', + token_endpoint=f'{prefix}/v1/token') + oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' + res = requests.get(oidc) + if res.status_code != 200: + return None + auth_metadata = res.json() + return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), + token_endpoint=auth_metadata.get('token_endpoint')) + + def debug_string(self) -> str: + """ Returns log-friendly representation of configured attributes """ + buf = [] + attrs_used = [] + envs_used = [] + for attr in Config.attributes(): + if attr.env and os.environ.get(attr.env): + envs_used.append(attr.env) + value = getattr(self, attr.name) + if not value: + continue + safe = '***' if attr.sensitive else f'{value}' + attrs_used.append(f'{attr.name}={safe}') + if attrs_used: + buf.append(f'Config: {", ".join(attrs_used)}') + if envs_used: + buf.append(f'Env: {", ".join(envs_used)}') + return '. '.join(buf) + + def to_dict(self) -> Dict[str, any]: + return self._inner + + @property + def sql_http_path(self) -> Optional[str]: + """(Experimental) Return HTTP path for SQL Drivers. + + If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument + used in construction of JDBC/ODBC DSN string. 
+ + See https://docs.databricks.com/integrations/jdbc-odbc-bi.html + """ + if (not self.cluster_id) and (not self.warehouse_id): + return None + if self.cluster_id and self.warehouse_id: + raise ValueError('cannot have both cluster_id and warehouse_id') + headers = self.authenticate() + headers['User-Agent'] = f'{self.user_agent} sdk-feature/sql-http-path' + if self.cluster_id: + response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers) + # get workspace ID from the response header + workspace_id = response.headers.get('x-databricks-org-id') + return f'sql/protocolv1/o/{workspace_id}/{self.cluster_id}' + if self.warehouse_id: + return f'/sql/1.0/warehouses/{self.warehouse_id}' + + @classmethod + def attributes(cls) -> Iterable[ConfigAttribute]: + """ Returns a list of Databricks SDK configuration metadata """ + if hasattr(cls, '_attributes'): + return cls._attributes + if sys.version_info[1] >= 10: + import inspect + anno = inspect.get_annotations(cls) + else: + # Python 3.7 compatibility: getting type hints require extra hop, as described in + # "Accessing The Annotations Dict Of An Object In Python 3.9 And Older" section of + # https://docs.python.org/3/howto/annotations.html + anno = cls.__dict__['__annotations__'] + attrs = [] + for name, v in cls.__dict__.items(): + if type(v) != ConfigAttribute: + continue + v.name = name + v.transform = anno.get(name, str) + attrs.append(v) + cls._attributes = attrs + return cls._attributes + + def _fix_host_if_needed(self): + if not self.host: + return + # fix url to remove trailing slash + o = urllib.parse.urlparse(self.host) + if not o.hostname: + # only hostname is specified + self.host = f"https://{self.host}" + else: + self.host = f"{o.scheme}://{o.netloc}" + + def _set_inner_config(self, keyword_args: Dict[str, any]): + for attr in self.attributes(): + if attr.name not in keyword_args: + continue + if keyword_args.get(attr.name, None) is None: + continue + # make sure that args are of correct type + self._inner[attr.name] = attr.transform(keyword_args[attr.name]) + + def _load_from_env(self): + found = False + for attr in Config.attributes(): + if not attr.env: + continue + if attr.name in self._inner: + continue + value = os.environ.get(attr.env) + if not value: + continue + self._inner[attr.name] = value + found = True + if found: + logger.debug('Loaded from environment') + + def _known_file_config_loader(self): + if not self.profile and (self.is_any_auth_configured or self.host + or self.azure_workspace_resource_id): + # skip loading configuration file if there's any auth configured + # directly as part of the Config() constructor. + return + config_file = self.config_file + if not config_file: + config_file = "~/.databrickscfg" + config_path = pathlib.Path(config_file).expanduser() + if not config_path.exists(): + logger.debug("%s does not exist", config_path) + return + ini_file = configparser.ConfigParser() + ini_file.read(config_path) + profile = self.profile + has_explicit_profile = self.profile is not None + # In Go SDK, we skip merging the profile with DEFAULT section, though Python's ConfigParser.items() + # is returning profile key-value pairs _including those from DEFAULT_. This is not what we expect + # from Unified Auth test suite at the moment. Hence, the private variable access. 
+ # See: https://docs.python.org/3/library/configparser.html#mapping-protocol-access + if not has_explicit_profile and not ini_file.defaults(): + logger.debug(f'{config_path} has no DEFAULT profile configured') + return + if not has_explicit_profile: + profile = "DEFAULT" + profiles = ini_file._sections + if ini_file.defaults(): + profiles['DEFAULT'] = ini_file.defaults() + if profile not in profiles: + raise ValueError(f'resolve: {config_path} has no {profile} profile configured') + raw_config = profiles[profile] + logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') + for k, v in raw_config.items(): + if k in self._inner: + # don't overwrite a value previously set + continue + self.__setattr__(k, v) + + def _validate(self): + auths_used = set() + for attr in Config.attributes(): + if attr.name not in self._inner: + continue + if not attr.auth: + continue + auths_used.add(attr.auth) + if len(auths_used) <= 1: + return + if self.auth_type: + # client has auth preference set + return + names = " and ".join(sorted(auths_used)) + raise ValueError(f'validate: more than one authorization method configured: {names}') + + def _init_auth(self): + try: + self._header_factory = self._credentials_provider(self) + self.auth_type = self._credentials_provider.auth_type() + if not self._header_factory: + raise ValueError('not configured') + except ValueError as e: + raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e + + def __repr__(self): + return f'<{self.debug_string()}>' + + def copy(self): + """Creates a copy of the config object. + All the copies share most of their internal state (ie, shared reference to fields such as credential_provider). + Copies have their own instances of the following fields + - `_user_agent_other_info` + """ + cpy: Config = copy.copy(self) + cpy._user_agent_other_info = copy.deepcopy(self._user_agent_other_info) + return cpy + + +class DatabricksError(IOError): + """ Generic error from Databricks REST API """ + + def __init__(self, + message: str = None, + *, + error_code: str = None, + detail: str = None, + status: str = None, + scimType: str = None, + error: str = None, + **kwargs): + if error: + # API 1.2 has different response format, let's adapt + message = error + if detail: + # Handle SCIM error message details + # @see https://tools.ietf.org/html/rfc7644#section-3.7.3 + if detail == "null": + message = "SCIM API Internal Error" + else: + message = detail + # add more context from SCIM responses + message = f"{scimType} {message}".strip(" ") + error_code = f"SCIM_{status}" + super().__init__(message if message else error) + self.error_code = error_code + self.kwargs = kwargs + + +class ApiClient: + _cfg: Config + + def __init__(self, cfg: Config = None): + + if cfg is None: + cfg = Config() + + self._cfg = cfg + self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96 + self._user_agent_base = cfg.user_agent + + retry_strategy = Retry( + total=6, + backoff_factor=1, + status_forcelist=[429], + allowed_methods={"POST"} | set(Retry.DEFAULT_ALLOWED_METHODS), + respect_retry_after_header=True, + raise_on_status=False, # return original response when retries have been exhausted + ) + + self._session = requests.Session() + self._session.auth = self._authenticate + self._session.mount("https://", HTTPAdapter(max_retries=retry_strategy)) + + @property + def account_id(self) -> str: + return self._cfg.account_id + + @property + def is_account_client(self) -> bool: + return 
self._cfg.is_account_client + + def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + headers = self._cfg.authenticate() + for k, v in headers.items(): + r.headers[k] = v + return r + + @staticmethod + def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: + # Convert True -> "true" for Databricks APIs to understand booleans. + # See: https://github.com/databricks/databricks-sdk-py/issues/142 + if query is None: + return None + return {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} + + def do(self, + method: str, + path: str, + query: dict = None, + body: dict = None, + raw: bool = False, + files=None, + data=None) -> dict: + headers = {'Accept': 'application/json', 'User-Agent': self._user_agent_base} + response = self._session.request(method, + f"{self._cfg.host}{path}", + params=self._fix_query_string(query), + json=body, + headers=headers, + files=files, + data=data, + stream=True if raw else False) + try: + self._record_request_log(response, raw=raw or data is not None or files is not None) + if not response.ok: + # TODO: experiment with traceback pruning for better readability + # See https://stackoverflow.com/a/58821552/277035 + payload = response.json() + raise self._make_nicer_error(status_code=response.status_code, **payload) from None + if raw: + return response.raw + if not len(response.content): + return {} + return response.json() + except requests.exceptions.JSONDecodeError: + message = self._make_sense_from_html(response.text) + if not message: + message = response.reason + raise self._make_nicer_error(message=message) from None + + @staticmethod + def _make_sense_from_html(txt: str) -> str: + matchers = [r'
<pre>(.*)</pre>
', r'(.*)'] + for attempt in matchers: + expr = re.compile(attempt, re.MULTILINE) + match = expr.search(txt) + if not match: + continue + return match.group(1).strip() + return txt + + def _make_nicer_error(self, status_code: int = 200, **kwargs) -> DatabricksError: + message = kwargs.get('message', 'request failed') + is_http_unauthorized_or_forbidden = status_code in (401, 403) + if is_http_unauthorized_or_forbidden: + message = self._cfg.wrap_debug_info(message) + kwargs['message'] = message + return DatabricksError(**kwargs) + + def _record_request_log(self, response: requests.Response, raw=False): + if not logger.isEnabledFor(logging.DEBUG): + return + request = response.request + url = urllib.parse.urlparse(request.url) + query = '' + if url.query: + query = f'?{urllib.parse.unquote(url.query)}' + sb = [f'{request.method} {urllib.parse.unquote(url.path)}{query}'] + if self._cfg.debug_headers: + if self._cfg.host: + sb.append(f'> * Host: {self._cfg.host}') + for k, v in request.headers.items(): + sb.append(f'> * {k}: {self._only_n_bytes(v, self._debug_truncate_bytes)}') + if request.body: + sb.append("> [raw stream]" if raw else self._redacted_dump("> ", request.body)) + sb.append(f'< {response.status_code} {response.reason}') + if raw and response.headers.get('Content-Type', None) != 'application/json': + # Raw streams with `Transfer-Encoding: chunked` do not have `Content-Type` header + sb.append("< [raw stream]") + elif response.content: + sb.append(self._redacted_dump("< ", response.content)) + logger.debug("\n".join(sb)) + + @staticmethod + def _mask(m: Dict[str, any]): + for k in m: + if k in {'bytes_value', 'string_value', 'token_value', 'value', 'content'}: + m[k] = "**REDACTED**" + + @staticmethod + def _map_keys(m: Dict[str, any]) -> List[str]: + keys = list(m.keys()) + keys.sort() + return keys + + @staticmethod + def _only_n_bytes(j: str, num_bytes: int = 96) -> str: + diff = len(j.encode('utf-8')) - num_bytes + if diff > 0: + return f"{j[:num_bytes]}... ({diff} more bytes)" + return j + + def _recursive_marshal_dict(self, m, budget) -> dict: + out = {} + self._mask(m) + for k in sorted(m.keys()): + raw = self._recursive_marshal(m[k], budget) + out[k] = raw + budget -= len(str(raw)) + return out + + def _recursive_marshal_list(self, s, budget) -> list: + out = [] + for i in range(len(s)): + if i > 0 >= budget: + out.append("... (%d additional elements)" % (len(s) - len(out))) + break + raw = self._recursive_marshal(s[i], budget) + out.append(raw) + budget -= len(str(raw)) + return out + + def _recursive_marshal(self, v: any, budget: int) -> any: + if isinstance(v, dict): + return self._recursive_marshal_dict(v, budget) + elif isinstance(v, list): + return self._recursive_marshal_list(v, budget) + elif isinstance(v, str): + return self._only_n_bytes(v, self._debug_truncate_bytes) + else: + return v + + def _redacted_dump(self, prefix: str, body: str) -> str: + if len(body) == 0: + return "" + try: + # Unmarshal body into primitive types. + tmp = json.loads(body) + max_bytes = 96 + if self._debug_truncate_bytes > max_bytes: + max_bytes = self._debug_truncate_bytes + # Re-marshal body taking redaction and character limit into account. 
+ raw = self._recursive_marshal(tmp, max_bytes) + return "\n".join([f'{prefix}{line}' for line in json.dumps(raw, indent=2).split("\n")]) + except JSONDecodeError: + return f'{prefix}[non-JSON document of {len(body)} bytes]' diff --git a/databricks/sdk/dbutils.py b/databricks/sdk/dbutils.py index 60d2ccd69..182b352f3 100644 --- a/databricks/sdk/dbutils.py +++ b/databricks/sdk/dbutils.py @@ -175,6 +175,19 @@ def __init__(self, config: 'Config' = None): self.fs = _FsUtil(dbfs_ext.DbfsExt(self._client), self.__getattr__) self.secrets = _SecretsUtil(workspace.SecretsAPI(self._client)) + self._widgets = None + + # When we import widget_impl, the init file checks whether user has the + # correct dependencies required for running on notebook or not (ipywidgets etc). + # We only want these checks (and the subsequent errors and warnings), to + # happen when the user actually uses widgets. + @property + def widgets(self): + if self._widgets is None: + from ._widgets import widget_impl + self._widgets = widget_impl() + + return self._widgets @property def _cluster_id(self) -> str: @@ -192,7 +205,7 @@ def _running_command_context(self) -> compute.ContextStatusResponse: return self._ctx self._clusters.ensure_cluster_is_running(self._cluster_id) self._ctx = self._commands.create(cluster_id=self._cluster_id, - language=compute.Language.python).result() + language=compute.Language.PYTHON).result() return self._ctx def __getattr__(self, util) -> '_ProxyUtil': @@ -245,10 +258,10 @@ def __init__(self, *, command_execution: compute.CommandExecutionAPI, _ascii_escape_re = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]') def _is_failed(self, results: compute.Results) -> bool: - return results.result_type == compute.ResultType.error + return results.result_type == compute.ResultType.ERROR def _text(self, results: compute.Results) -> str: - if results.result_type != compute.ResultType.text: + if results.result_type != compute.ResultType.TEXT: return '' return self._out_re.sub("", str(results.data)) @@ -292,10 +305,10 @@ def __call__(self, *args, **kwargs): ''' ctx = self._context_factory() result = self._commands.execute(cluster_id=self._cluster_id, - language=compute.Language.python, + language=compute.Language.PYTHON, context_id=ctx.id, command=code).result() - if result.status == compute.CommandStatus.Finished: + if result.status == compute.CommandStatus.FINISHED: self._raise_if_failed(result.results) raw = result.results.data return json.loads(raw) diff --git a/databricks/sdk/mixins/workspace.py b/databricks/sdk/mixins/workspace.py index 12040b64e..9409a1112 100644 --- a/databricks/sdk/mixins/workspace.py +++ b/databricks/sdk/mixins/workspace.py @@ -84,7 +84,7 @@ def upload(self, return self._api.do('POST', '/api/2.0/workspace/import', files={'content': content}, data=data) except DatabricksError as e: if e.error_code == 'INVALID_PARAMETER_VALUE': - msg = f'Perhaps you forgot to specify the `format=ExportFormat.AUTO`. {e}' + msg = f'Perhaps you forgot to specify the `format=ImportFormat.AUTO`. 
{e}' raise DatabricksError(message=msg, error_code=e.error_code) else: raise e diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index b79b9b2c6..609b24f95 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + is_local_implementation = True # All objects that are injected into the Notebook's user namespace should also be made @@ -31,15 +33,22 @@ _globals[var] = userNamespaceGlobals[var] is_local_implementation = False except ImportError: + from typing import Type, cast + # OSS implementation is_local_implementation = True try: + from . import stub from .stub import * + dbutils_type = Type[stub.dbutils] except (ImportError, NameError): from databricks.sdk.dbutils import RemoteDbUtils # this assumes that all environment variables are set dbutils = RemoteDbUtils() + dbutils_type = RemoteDbUtils + + dbutils = cast(dbutils_type, dbutils) __all__ = ['dbutils'] if is_local_implementation else dbruntime_objects diff --git a/databricks/sdk/service/_internal.py b/databricks/sdk/service/_internal.py index 9adadf412..f7b929480 100644 --- a/databricks/sdk/service/_internal.py +++ b/databricks/sdk/service/_internal.py @@ -18,7 +18,7 @@ def _repeated(d: Dict[str, any], field: str, cls: Type) -> any: def _enum(d: Dict[str, any], field: str, cls: Type) -> any: if field not in d or not d[field]: return None - return getattr(cls, '__members__').get(d[field], None) + return next((v for v in getattr(cls, '__members__').values() if v.value == d[field]), None) ReturnType = TypeVar('ReturnType') diff --git a/databricks/sdk/service/catalog.py b/databricks/sdk/service/catalog.py index ebe7dacf7..9183b2cd3 100755 --- a/databricks/sdk/service/catalog.py +++ b/databricks/sdk/service/catalog.py @@ -1011,10 +1011,10 @@ class DisableRequest: class DisableSchemaName(Enum): - access = 'access' - billing = 'billing' - lineage = 'lineage' - operational_data = 'operational_data' + ACCESS = 'access' + BILLING = 'billing' + LINEAGE = 'lineage' + OPERATIONAL_DATA = 'operational_data' @dataclass @@ -1116,10 +1116,10 @@ class EnableRequest: class EnableSchemaName(Enum): - access = 'access' - billing = 'billing' - lineage = 'lineage' - operational_data = 'operational_data' + ACCESS = 'access' + BILLING = 'billing' + LINEAGE = 'lineage' + OPERATIONAL_DATA = 'operational_data' @dataclass @@ -2110,16 +2110,17 @@ def from_dict(cls, d: Dict[str, any]) -> 'SchemaInfo': class SecurableType(Enum): """The type of Unity Catalog securable""" - CATALOG = 'CATALOG' - EXTERNAL_LOCATION = 'EXTERNAL_LOCATION' - FUNCTION = 'FUNCTION' - METASTORE = 'METASTORE' - PROVIDER = 'PROVIDER' - RECIPIENT = 'RECIPIENT' - SCHEMA = 'SCHEMA' - SHARE = 'SHARE' - STORAGE_CREDENTIAL = 'STORAGE_CREDENTIAL' - TABLE = 'TABLE' + CATALOG = 'catalog' + EXTERNAL_LOCATION = 'external_location' + FUNCTION = 'function' + METASTORE = 'metastore' + PIPELINE = 'pipeline' + PROVIDER = 'provider' + RECIPIENT = 'recipient' + SCHEMA = 'schema' + SHARE = 'share' + STORAGE_CREDENTIAL = 'storage_credential' + TABLE = 'table' @dataclass @@ -2700,6 +2701,24 @@ def from_dict(cls, d: Dict[str, any]) -> 'UpdateStorageCredential': skip_validation=d.get('skip_validation', None)) +@dataclass +class UpdateTableRequest: + """Update a table owner.""" + + full_name: str + owner: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.full_name is not None: body['full_name'] = self.full_name + if self.owner is not None: body['owner'] = self.owner + return 
body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'UpdateTableRequest': + return cls(full_name=d.get('full_name', None), owner=d.get('owner', None)) + + @dataclass class UpdateVolumeRequestContent: comment: Optional[str] = None @@ -4733,7 +4752,7 @@ def enable(self, metastore_id: str, schema_name: EnableSchemaName, **kwargs): request = EnableRequest(metastore_id=metastore_id, schema_name=schema_name) self._api.do( - 'POST', + 'PUT', f'/api/2.1/unity-catalog/metastores/{request.metastore_id}/systemschemas/{request.schema_name.value}' ) @@ -5008,6 +5027,26 @@ def list_summaries(self, return query['page_token'] = json['next_page_token'] + def update(self, full_name: str, *, owner: Optional[str] = None, **kwargs): + """Update a table owner. + + Change the owner of the table. The caller must be the owner of the parent catalog, have the + **USE_CATALOG** privilege on the parent catalog and be the owner of the parent schema, or be the owner + of the table and have the **USE_CATALOG** privilege on the parent catalog and the **USE_SCHEMA** + privilege on the parent schema. + + :param full_name: str + Full name of the table. + :param owner: str (optional) + + + """ + request = kwargs.get('request', None) + if not request: # request is not given through keyed args + request = UpdateTableRequest(full_name=full_name, owner=owner) + body = request.as_dict() + self._api.do('PATCH', f'/api/2.1/unity-catalog/tables/{request.full_name}', body=body) + class VolumesAPI: """Volumes are a Unity Catalog (UC) capability for accessing, storing, governing, organizing and processing diff --git a/databricks/sdk/service/compute.py b/databricks/sdk/service/compute.py index 98734ec00..af2f6e6b6 100755 --- a/databricks/sdk/service/compute.py +++ b/databricks/sdk/service/compute.py @@ -209,8 +209,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'CloudProviderNodeInfo': class CloudProviderNodeStatus(Enum): - NotAvailableInRegion = 'NotAvailableInRegion' - NotEnabledOnSubscription = 'NotEnabledOnSubscription' + NOT_AVAILABLE_IN_REGION = 'NotAvailableInRegion' + NOT_ENABLED_ON_SUBSCRIPTION = 'NotEnabledOnSubscription' @dataclass @@ -664,12 +664,12 @@ def from_dict(cls, d: Dict[str, any]) -> 'Command': class CommandStatus(Enum): - Cancelled = 'Cancelled' - Cancelling = 'Cancelling' - Error = 'Error' - Finished = 'Finished' - Queued = 'Queued' - Running = 'Running' + CANCELLED = 'Cancelled' + CANCELLING = 'Cancelling' + ERROR = 'Error' + FINISHED = 'Finished' + QUEUED = 'Queued' + RUNNING = 'Running' @dataclass @@ -723,9 +723,9 @@ class ComputeSpecKind(Enum): class ContextStatus(Enum): - Error = 'Error' - Pending = 'Pending' - Running = 'Running' + ERROR = 'Error' + PENDING = 'Pending' + RUNNING = 'Running' @dataclass @@ -1813,6 +1813,8 @@ class GetInstancePoolRequest: @dataclass class GetPolicyFamilyRequest: + """Get policy family information""" + policy_family_id: str @@ -2227,9 +2229,9 @@ def from_dict(cls, d: Dict[str, any]) -> 'InstanceProfile': class Language(Enum): - python = 'python' - scala = 'scala' - sql = 'sql' + PYTHON = 'python' + SCALA = 'scala' + SQL = 'sql' @dataclass @@ -2428,6 +2430,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'ListPoliciesResponse': @dataclass class ListPolicyFamiliesRequest: + """List policy families""" + max_results: Optional[int] = None page_token: Optional[str] = None @@ -2809,11 +2813,11 @@ def from_dict(cls, d: Dict[str, any]) -> 'RestartCluster': class ResultType(Enum): - error = 'error' - image = 'image' - images = 'images' - table = 'table' - text = 'text' + ERROR = 
'error' + IMAGE = 'image' + IMAGES = 'images' + TABLE = 'table' + TEXT = 'text' @dataclass @@ -4280,8 +4284,8 @@ def wait_command_status_command_execution_cancelled( timeout=timedelta(minutes=20), callback: Optional[Callable[[CommandStatusResponse], None]] = None) -> CommandStatusResponse: deadline = time.time() + timeout.total_seconds() - target_states = (CommandStatus.Cancelled, ) - failure_states = (CommandStatus.Error, ) + target_states = (CommandStatus.CANCELLED, ) + failure_states = (CommandStatus.ERROR, ) status_message = 'polling...' attempt = 1 while time.time() < deadline: @@ -4315,8 +4319,8 @@ def wait_command_status_command_execution_finished_or_error( timeout=timedelta(minutes=20), callback: Optional[Callable[[CommandStatusResponse], None]] = None) -> CommandStatusResponse: deadline = time.time() + timeout.total_seconds() - target_states = (CommandStatus.Finished, CommandStatus.Error, ) - failure_states = (CommandStatus.Cancelled, CommandStatus.Cancelling, ) + target_states = (CommandStatus.FINISHED, CommandStatus.ERROR, ) + failure_states = (CommandStatus.CANCELLED, CommandStatus.CANCELLING, ) status_message = 'polling...' attempt = 1 while time.time() < deadline: @@ -4347,8 +4351,8 @@ def wait_context_status_command_execution_running( timeout=timedelta(minutes=20), callback: Optional[Callable[[ContextStatusResponse], None]] = None) -> ContextStatusResponse: deadline = time.time() + timeout.total_seconds() - target_states = (ContextStatus.Running, ) - failure_states = (ContextStatus.Error, ) + target_states = (ContextStatus.RUNNING, ) + failure_states = (ContextStatus.ERROR, ) status_message = 'polling...' attempt = 1 while time.time() < deadline: @@ -4999,11 +5003,10 @@ def add(self, [Databricks SQL Serverless]: https://docs.databricks.com/sql/admin/serverless.html :param is_meta_instance_profile: bool (optional) - By default, Databricks validates that it has sufficient permissions to launch instances with the - instance profile. This validation uses AWS dry-run mode for the RunInstances API. If validation - fails with an error message that does not indicate an IAM related permission issue, (e.g. `Your - requested instance type is not supported in your requested availability zone`), you can pass this - flag to skip the validation and forcibly add the instance profile. + Boolean flag indicating whether the instance profile should only be used in credential passthrough + scenarios. If true, it means the instance profile contains an meta IAM role which could assume a + wide range of roles. Therefore it should always be used with authorization. This field is optional, + the default value is `false`. :param skip_validation: bool (optional) By default, Databricks validates that it has sufficient permissions to launch instances with the instance profile. This validation uses AWS dry-run mode for the RunInstances API. If validation @@ -5054,11 +5057,10 @@ def edit(self, [Databricks SQL Serverless]: https://docs.databricks.com/sql/admin/serverless.html :param is_meta_instance_profile: bool (optional) - By default, Databricks validates that it has sufficient permissions to launch instances with the - instance profile. This validation uses AWS dry-run mode for the RunInstances API. If validation - fails with an error message that does not indicate an IAM related permission issue, (e.g. `Your - requested instance type is not supported in your requested availability zone`), you can pass this - flag to skip the validation and forcibly add the instance profile. 
+ Boolean flag indicating whether the instance profile should only be used in credential passthrough + scenarios. If true, it means the instance profile contains an meta IAM role which could assume a + wide range of roles. Therefore it should always be used with authorization. This field is optional, + the default value is `false`. """ @@ -5226,7 +5228,14 @@ def __init__(self, api_client): self._api = api_client def get(self, policy_family_id: str, **kwargs) -> PolicyFamily: - + """Get policy family information. + + Retrieve the information for an policy family based on its identifier. + + :param policy_family_id: str + + :returns: :class:`PolicyFamily` + """ request = kwargs.get('request', None) if not request: # request is not given through keyed args request = GetPolicyFamilyRequest(policy_family_id=policy_family_id) @@ -5239,7 +5248,17 @@ def list(self, max_results: Optional[int] = None, page_token: Optional[str] = None, **kwargs) -> Iterator[PolicyFamily]: - + """List policy families. + + Retrieve a list of policy families. This API is paginated. + + :param max_results: int (optional) + The max number of policy families to return. + :param page_token: str (optional) + A token that can be used to get the next page of results. + + :returns: Iterator over :class:`PolicyFamily` + """ request = kwargs.get('request', None) if not request: # request is not given through keyed args request = ListPolicyFamiliesRequest(max_results=max_results, page_token=page_token) diff --git a/databricks/sdk/service/iam.py b/databricks/sdk/service/iam.py index 7dec4635c..abd2692fb 100755 --- a/databricks/sdk/service/iam.py +++ b/databricks/sdk/service/iam.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from enum import Enum -from typing import Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional from ._internal import _enum, _from_dict, _repeated @@ -265,6 +265,7 @@ class Group: groups: Optional['List[ComplexValue]'] = None id: Optional[str] = None members: Optional['List[ComplexValue]'] = None + meta: Optional['ResourceMeta'] = None roles: Optional['List[ComplexValue]'] = None def as_dict(self) -> dict: @@ -275,6 +276,7 @@ def as_dict(self) -> dict: if self.groups: body['groups'] = [v.as_dict() for v in self.groups] if self.id is not None: body['id'] = self.id if self.members: body['members'] = [v.as_dict() for v in self.members] + if self.meta: body['meta'] = self.meta.as_dict() if self.roles: body['roles'] = [v.as_dict() for v in self.roles] return body @@ -286,6 +288,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'Group': groups=_repeated(d, 'groups', ComplexValue), id=d.get('id', None), members=_repeated(d, 'members', ComplexValue), + meta=_from_dict(d, 'meta', ResourceMeta), roles=_repeated(d, 'roles', ComplexValue)) @@ -402,8 +405,8 @@ class ListServicePrincipalsRequest: class ListSortOrder(Enum): - ascending = 'ascending' - descending = 'descending' + ASCENDING = 'ascending' + DESCENDING = 'descending' @dataclass @@ -490,29 +493,33 @@ def from_dict(cls, d: Dict[str, any]) -> 'ObjectPermissions': class PartialUpdate: id: Optional[str] = None operations: Optional['List[Patch]'] = None + schema: Optional['List[PatchSchema]'] = None def as_dict(self) -> dict: body = {} if self.id is not None: body['id'] = self.id - if self.operations: body['operations'] = [v.as_dict() for v in self.operations] + if self.operations: body['Operations'] = [v.as_dict() for v in self.operations] + if self.schema: body['schema'] = [v for v in self.schema] return body 
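For illustration only (not part of this diff): the reworked `PartialUpdate.as_dict` above now emits the SCIM-style `Operations` key and passes the new `schema` list through as given. A minimal sketch of the resulting request body, using hypothetical IDs:

    from databricks.sdk.service.iam import Patch, PatchOp, PatchSchema, PartialUpdate

    # Hypothetical group and member IDs, purely for illustration.
    update = PartialUpdate(id='123',
                           operations=[Patch(op=PatchOp.ADD, path='members', value=[{'value': '456'}])],
                           schema=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES20_PATCH_OP])
    # as_dict() serializes each Patch (the op enum becomes its string value) under 'Operations'
    # and copies the schema list through unchanged under 'schema'.
    body = update.as_dict()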
@classmethod def from_dict(cls, d: Dict[str, any]) -> 'PartialUpdate': - return cls(id=d.get('id', None), operations=_repeated(d, 'operations', Patch)) + return cls(id=d.get('id', None), + operations=_repeated(d, 'Operations', Patch), + schema=d.get('schema', None)) @dataclass class Patch: op: Optional['PatchOp'] = None path: Optional[str] = None - value: Optional[str] = None + value: Optional[Any] = None def as_dict(self) -> dict: body = {} if self.op is not None: body['op'] = self.op.value if self.path is not None: body['path'] = self.path - if self.value is not None: body['value'] = self.value + if self.value: body['value'] = self.value return body @classmethod @@ -523,9 +530,14 @@ def from_dict(cls, d: Dict[str, any]) -> 'Patch': class PatchOp(Enum): """Type of patch operation.""" - add = 'add' - remove = 'remove' - replace = 'replace' + ADD = 'add' + REMOVE = 'remove' + REPLACE = 'replace' + + +class PatchSchema(Enum): + + URN_IETF_PARAMS_SCIM_API_MESSAGES20_PATCH_OP = 'urn:ietf:params:scim:api:messages:2.0:PatchOp' @dataclass @@ -685,6 +697,20 @@ def from_dict(cls, d: Dict[str, any]) -> 'PrincipalOutput': user_name=d.get('user_name', None)) +@dataclass +class ResourceMeta: + resource_type: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.resource_type is not None: body['resourceType'] = self.resource_type + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResourceMeta': + return cls(resource_type=d.get('resourceType', None)) + + @dataclass class RuleSetResponse: etag: Optional[str] = None @@ -1044,6 +1070,7 @@ def create(self, groups: Optional[List[ComplexValue]] = None, id: Optional[str] = None, members: Optional[List[ComplexValue]] = None, + meta: Optional[ResourceMeta] = None, roles: Optional[List[ComplexValue]] = None, **kwargs) -> Group: """Create a new group. @@ -1058,6 +1085,8 @@ def create(self, :param id: str (optional) Databricks group ID :param members: List[:class:`ComplexValue`] (optional) + :param meta: :class:`ResourceMeta` (optional) + Container for the group identifier. Workspace local versus account. :param roles: List[:class:`ComplexValue`] (optional) :returns: :class:`Group` @@ -1070,6 +1099,7 @@ def create(self, groups=groups, id=id, members=members, + meta=meta, roles=roles) body = request.as_dict() @@ -1167,7 +1197,12 @@ def list(self, json = self._api.do('GET', f'/api/2.0/accounts/{self._api.account_id}/scim/v2/Groups', query=query) return [Group.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update group details. Partially updates the details of a group. @@ -1175,12 +1210,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a group in the Databricks account. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. 
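A usage sketch for the call above, illustrative only: the group ID and display name are hypothetical, and `a` is assumed to be an already-configured `databricks.sdk.AccountClient`.

    # Hypothetical ID; `a` is assumed to be a configured AccountClient.
    a.groups.patch(id='123',
                   operations=[Patch(op=PatchOp.REPLACE, path='displayName', value='Data Engineering')],
                   schema=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES20_PATCH_OP])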
""" request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/accounts/{self._api.account_id}/scim/v2/Groups/{request.id}', @@ -1194,6 +1231,7 @@ def update(self, external_id: Optional[str] = None, groups: Optional[List[ComplexValue]] = None, members: Optional[List[ComplexValue]] = None, + meta: Optional[ResourceMeta] = None, roles: Optional[List[ComplexValue]] = None, **kwargs): """Replace a group. @@ -1208,6 +1246,8 @@ def update(self, :param external_id: str (optional) :param groups: List[:class:`ComplexValue`] (optional) :param members: List[:class:`ComplexValue`] (optional) + :param meta: :class:`ResourceMeta` (optional) + Container for the group identifier. Workspace local versus account. :param roles: List[:class:`ComplexValue`] (optional) @@ -1220,6 +1260,7 @@ def update(self, groups=groups, id=id, members=members, + meta=meta, roles=roles) body = request.as_dict() self._api.do('PUT', @@ -1379,7 +1420,12 @@ def list(self, query=query) return [ServicePrincipal.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update service principal details. Partially updates the details of a single service principal in the Databricks account. @@ -1387,12 +1433,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a service principal in the Databricks account. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. """ request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/accounts/{self._api.account_id}/scim/v2/ServicePrincipals/{request.id}', @@ -1606,7 +1654,12 @@ def list(self, json = self._api.do('GET', f'/api/2.0/accounts/{self._api.account_id}/scim/v2/Users', query=query) return [User.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update user details. Partially updates a user resource by applying the supplied operations on specific user attributes. @@ -1614,12 +1667,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a user in the Databricks account. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. 
""" request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/accounts/{self._api.account_id}/scim/v2/Users/{request.id}', @@ -1713,6 +1768,7 @@ def create(self, groups: Optional[List[ComplexValue]] = None, id: Optional[str] = None, members: Optional[List[ComplexValue]] = None, + meta: Optional[ResourceMeta] = None, roles: Optional[List[ComplexValue]] = None, **kwargs) -> Group: """Create a new group. @@ -1727,6 +1783,8 @@ def create(self, :param id: str (optional) Databricks group ID :param members: List[:class:`ComplexValue`] (optional) + :param meta: :class:`ResourceMeta` (optional) + Container for the group identifier. Workspace local versus account. :param roles: List[:class:`ComplexValue`] (optional) :returns: :class:`Group` @@ -1739,6 +1797,7 @@ def create(self, groups=groups, id=id, members=members, + meta=meta, roles=roles) body = request.as_dict() @@ -1836,7 +1895,12 @@ def list(self, json = self._api.do('GET', '/api/2.0/preview/scim/v2/Groups', query=query) return [Group.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update group details. Partially updates the details of a group. @@ -1844,12 +1908,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a group in the Databricks workspace. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. """ request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/preview/scim/v2/Groups/{request.id}', body=body) @@ -1861,6 +1927,7 @@ def update(self, external_id: Optional[str] = None, groups: Optional[List[ComplexValue]] = None, members: Optional[List[ComplexValue]] = None, + meta: Optional[ResourceMeta] = None, roles: Optional[List[ComplexValue]] = None, **kwargs): """Replace a group. @@ -1875,6 +1942,8 @@ def update(self, :param external_id: str (optional) :param groups: List[:class:`ComplexValue`] (optional) :param members: List[:class:`ComplexValue`] (optional) + :param meta: :class:`ResourceMeta` (optional) + Container for the group identifier. Workspace local versus account. 
:param roles: List[:class:`ComplexValue`] (optional) @@ -1887,6 +1956,7 @@ def update(self, groups=groups, id=id, members=members, + meta=meta, roles=roles) body = request.as_dict() self._api.do('PUT', f'/api/2.0/preview/scim/v2/Groups/{request.id}', body=body) @@ -2146,7 +2216,12 @@ def list(self, json = self._api.do('GET', '/api/2.0/preview/scim/v2/ServicePrincipals', query=query) return [ServicePrincipal.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update service principal details. Partially updates the details of a single service principal in the Databricks workspace. @@ -2154,12 +2229,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a service principal in the Databricks workspace. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. """ request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/preview/scim/v2/ServicePrincipals/{request.id}', body=body) @@ -2369,7 +2446,12 @@ def list(self, json = self._api.do('GET', '/api/2.0/preview/scim/v2/Users', query=query) return [User.from_dict(v) for v in json.get('Resources', [])] - def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): + def patch(self, + id: str, + *, + operations: Optional[List[Patch]] = None, + schema: Optional[List[PatchSchema]] = None, + **kwargs): """Update user details. Partially updates a user resource by applying the supplied operations on specific user attributes. @@ -2377,12 +2459,14 @@ def patch(self, id: str, *, operations: Optional[List[Patch]] = None, **kwargs): :param id: str Unique ID for a user in the Databricks workspace. :param operations: List[:class:`Patch`] (optional) + :param schema: List[:class:`PatchSchema`] (optional) + The schema of the patch request. Must be ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]. 
""" request = kwargs.get('request', None) if not request: # request is not given through keyed args - request = PartialUpdate(id=id, operations=operations) + request = PartialUpdate(id=id, operations=operations, schema=schema) body = request.as_dict() self._api.do('PATCH', f'/api/2.0/preview/scim/v2/Users/{request.id}', body=body) diff --git a/databricks/sdk/service/jobs.py b/databricks/sdk/service/jobs.py index c4b30330d..eb4a9f017 100755 --- a/databricks/sdk/service/jobs.py +++ b/databricks/sdk/service/jobs.py @@ -54,6 +54,7 @@ class BaseRun: git_source: Optional['GitSource'] = None job_clusters: Optional['List[JobCluster]'] = None job_id: Optional[int] = None + job_parameters: Optional['List[JobParameter]'] = None number_in_job: Optional[int] = None original_attempt_run_id: Optional[int] = None overriding_parameters: Optional['RunParameters'] = None @@ -68,6 +69,7 @@ class BaseRun: state: Optional['RunState'] = None tasks: Optional['List[RunTask]'] = None trigger: Optional['TriggerType'] = None + trigger_info: Optional['TriggerInfo'] = None def as_dict(self) -> dict: body = {} @@ -82,6 +84,7 @@ def as_dict(self) -> dict: if self.git_source: body['git_source'] = self.git_source.as_dict() if self.job_clusters: body['job_clusters'] = [v.as_dict() for v in self.job_clusters] if self.job_id is not None: body['job_id'] = self.job_id + if self.job_parameters: body['job_parameters'] = [v.as_dict() for v in self.job_parameters] if self.number_in_job is not None: body['number_in_job'] = self.number_in_job if self.original_attempt_run_id is not None: body['original_attempt_run_id'] = self.original_attempt_run_id @@ -97,6 +100,7 @@ def as_dict(self) -> dict: if self.state: body['state'] = self.state.as_dict() if self.tasks: body['tasks'] = [v.as_dict() for v in self.tasks] if self.trigger is not None: body['trigger'] = self.trigger.value + if self.trigger_info: body['trigger_info'] = self.trigger_info.as_dict() return body @classmethod @@ -112,6 +116,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'BaseRun': git_source=_from_dict(d, 'git_source', GitSource), job_clusters=_repeated(d, 'job_clusters', JobCluster), job_id=d.get('job_id', None), + job_parameters=_repeated(d, 'job_parameters', JobParameter), number_in_job=d.get('number_in_job', None), original_attempt_run_id=d.get('original_attempt_run_id', None), overriding_parameters=_from_dict(d, 'overriding_parameters', RunParameters), @@ -125,7 +130,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'BaseRun': start_time=d.get('start_time', None), state=_from_dict(d, 'state', RunState), tasks=_repeated(d, 'tasks', RunTask), - trigger=_enum(d, 'trigger', TriggerType)) + trigger=_enum(d, 'trigger', TriggerType), + trigger_info=_from_dict(d, 'trigger_info', TriggerInfo)) @dataclass @@ -255,6 +261,7 @@ class CreateJob: max_concurrent_runs: Optional[int] = None name: Optional[str] = None notification_settings: Optional['JobNotificationSettings'] = None + parameters: Optional['List[JobParameterDefinition]'] = None run_as: Optional['JobRunAs'] = None schedule: Optional['CronSchedule'] = None tags: Optional['Dict[str,str]'] = None @@ -276,6 +283,7 @@ def as_dict(self) -> dict: if self.max_concurrent_runs is not None: body['max_concurrent_runs'] = self.max_concurrent_runs if self.name is not None: body['name'] = self.name if self.notification_settings: body['notification_settings'] = self.notification_settings.as_dict() + if self.parameters: body['parameters'] = [v.as_dict() for v in self.parameters] if self.run_as: body['run_as'] = self.run_as.as_dict() if 
self.schedule: body['schedule'] = self.schedule.as_dict() if self.tags: body['tags'] = self.tags @@ -297,6 +305,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'CreateJob': max_concurrent_runs=d.get('max_concurrent_runs', None), name=d.get('name', None), notification_settings=_from_dict(d, 'notification_settings', JobNotificationSettings), + parameters=_repeated(d, 'parameters', JobParameterDefinition), run_as=_from_dict(d, 'run_as', JobRunAs), schedule=_from_dict(d, 'schedule', CronSchedule), tags=d.get('tags', None), @@ -439,14 +448,14 @@ class ExportRunRequest: @dataclass class FileArrivalTriggerConfiguration: - min_time_between_trigger_seconds: Optional[int] = None + min_time_between_triggers_seconds: Optional[int] = None url: Optional[str] = None wait_after_last_change_seconds: Optional[int] = None def as_dict(self) -> dict: body = {} - if self.min_time_between_trigger_seconds is not None: - body['min_time_between_trigger_seconds'] = self.min_time_between_trigger_seconds + if self.min_time_between_triggers_seconds is not None: + body['min_time_between_triggers_seconds'] = self.min_time_between_triggers_seconds if self.url is not None: body['url'] = self.url if self.wait_after_last_change_seconds is not None: body['wait_after_last_change_seconds'] = self.wait_after_last_change_seconds @@ -454,7 +463,7 @@ def as_dict(self) -> dict: @classmethod def from_dict(cls, d: Dict[str, any]) -> 'FileArrivalTriggerConfiguration': - return cls(min_time_between_trigger_seconds=d.get('min_time_between_trigger_seconds', None), + return cls(min_time_between_triggers_seconds=d.get('min_time_between_triggers_seconds', None), url=d.get('url', None), wait_after_last_change_seconds=d.get('wait_after_last_change_seconds', None)) @@ -489,14 +498,14 @@ class GetRunRequest: class GitProvider(Enum): - awsCodeCommit = 'awsCodeCommit' - azureDevOpsServices = 'azureDevOpsServices' - bitbucketCloud = 'bitbucketCloud' - bitbucketServer = 'bitbucketServer' - gitHub = 'gitHub' - gitHubEnterprise = 'gitHubEnterprise' - gitLab = 'gitLab' - gitLabEnterpriseEdition = 'gitLabEnterpriseEdition' + AWS_CODE_COMMIT = 'awsCodeCommit' + AZURE_DEV_OPS_SERVICES = 'azureDevOpsServices' + BITBUCKET_CLOUD = 'bitbucketCloud' + BITBUCKET_SERVER = 'bitbucketServer' + GIT_HUB = 'gitHub' + GIT_HUB_ENTERPRISE = 'gitHubEnterprise' + GIT_LAB = 'gitLab' + GIT_LAB_ENTERPRISE_EDITION = 'gitLabEnterpriseEdition' @dataclass @@ -527,6 +536,7 @@ class GitSource: git_commit: Optional[str] = None git_snapshot: Optional['GitSnapshot'] = None git_tag: Optional[str] = None + job_source: Optional['JobSource'] = None def as_dict(self) -> dict: body = {} @@ -536,6 +546,7 @@ def as_dict(self) -> dict: if self.git_snapshot: body['git_snapshot'] = self.git_snapshot.as_dict() if self.git_tag is not None: body['git_tag'] = self.git_tag if self.git_url is not None: body['git_url'] = self.git_url + if self.job_source: body['job_source'] = self.job_source.as_dict() return body @classmethod @@ -545,7 +556,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'GitSource': git_provider=_enum(d, 'git_provider', GitProvider), git_snapshot=_from_dict(d, 'git_snapshot', GitSnapshot), git_tag=d.get('git_tag', None), - git_url=d.get('git_url', None)) + git_url=d.get('git_url', None), + job_source=_from_dict(d, 'job_source', JobSource)) @dataclass @@ -653,6 +665,40 @@ def from_dict(cls, d: Dict[str, any]) -> 'JobNotificationSettings': no_alert_for_skipped_runs=d.get('no_alert_for_skipped_runs', None)) +@dataclass +class JobParameter: + default: Optional[str] = None + name: 
Optional[str] = None + value: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.default is not None: body['default'] = self.default + if self.name is not None: body['name'] = self.name + if self.value is not None: body['value'] = self.value + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'JobParameter': + return cls(default=d.get('default', None), name=d.get('name', None), value=d.get('value', None)) + + +@dataclass +class JobParameterDefinition: + name: str + default: str + + def as_dict(self) -> dict: + body = {} + if self.default is not None: body['default'] = self.default + if self.name is not None: body['name'] = self.name + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'JobParameterDefinition': + return cls(default=d.get('default', None), name=d.get('name', None)) + + @dataclass class JobRunAs: """Write-only setting, available only in Create/Update/Reset and Submit calls. Specifies the user @@ -689,6 +735,7 @@ class JobSettings: max_concurrent_runs: Optional[int] = None name: Optional[str] = None notification_settings: Optional['JobNotificationSettings'] = None + parameters: Optional['List[JobParameterDefinition]'] = None run_as: Optional['JobRunAs'] = None schedule: Optional['CronSchedule'] = None tags: Optional['Dict[str,str]'] = None @@ -708,6 +755,7 @@ def as_dict(self) -> dict: if self.max_concurrent_runs is not None: body['max_concurrent_runs'] = self.max_concurrent_runs if self.name is not None: body['name'] = self.name if self.notification_settings: body['notification_settings'] = self.notification_settings.as_dict() + if self.parameters: body['parameters'] = [v.as_dict() for v in self.parameters] if self.run_as: body['run_as'] = self.run_as.as_dict() if self.schedule: body['schedule'] = self.schedule.as_dict() if self.tags: body['tags'] = self.tags @@ -728,6 +776,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'JobSettings': max_concurrent_runs=d.get('max_concurrent_runs', None), name=d.get('name', None), notification_settings=_from_dict(d, 'notification_settings', JobNotificationSettings), + parameters=_repeated(d, 'parameters', JobParameterDefinition), run_as=_from_dict(d, 'run_as', JobRunAs), schedule=_from_dict(d, 'schedule', CronSchedule), tags=d.get('tags', None), @@ -737,6 +786,36 @@ def from_dict(cls, d: Dict[str, any]) -> 'JobSettings': webhook_notifications=_from_dict(d, 'webhook_notifications', WebhookNotifications)) +@dataclass +class JobSource: + """The source of the job specification in the remote repository when the job is source controlled.""" + + job_config_path: str + import_from_git_branch: str + dirty_state: Optional['JobSourceDirtyState'] = None + + def as_dict(self) -> dict: + body = {} + if self.dirty_state is not None: body['dirty_state'] = self.dirty_state.value + if self.import_from_git_branch is not None: + body['import_from_git_branch'] = self.import_from_git_branch + if self.job_config_path is not None: body['job_config_path'] = self.job_config_path + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'JobSource': + return cls(dirty_state=_enum(d, 'dirty_state', JobSourceDirtyState), + import_from_git_branch=d.get('import_from_git_branch', None), + job_config_path=d.get('job_config_path', None)) + + +class JobSourceDirtyState(Enum): + """This describes an enum""" + + DISCONNECTED = 'DISCONNECTED' + NOT_SYNCED = 'NOT_SYNCED' + + @dataclass class ListJobsRequest: """List jobs""" @@ -854,6 +933,9 @@ def from_dict(cls, d: Dict[str, any]) -> 'NotebookTask': 
source=_enum(d, 'source', Source)) +ParamPairs = Dict[str, str] + + class PauseStatus(Enum): PAUSED = 'PAUSED' @@ -960,6 +1042,7 @@ class RepairRun: python_named_params: Optional['Dict[str,str]'] = None python_params: Optional['List[str]'] = None rerun_all_failed_tasks: Optional[bool] = None + rerun_dependent_tasks: Optional[bool] = None rerun_tasks: Optional['List[str]'] = None spark_submit_params: Optional['List[str]'] = None sql_params: Optional['Dict[str,str]'] = None @@ -975,6 +1058,7 @@ def as_dict(self) -> dict: if self.python_params: body['python_params'] = [v for v in self.python_params] if self.rerun_all_failed_tasks is not None: body['rerun_all_failed_tasks'] = self.rerun_all_failed_tasks + if self.rerun_dependent_tasks is not None: body['rerun_dependent_tasks'] = self.rerun_dependent_tasks if self.rerun_tasks: body['rerun_tasks'] = [v for v in self.rerun_tasks] if self.run_id is not None: body['run_id'] = self.run_id if self.spark_submit_params: body['spark_submit_params'] = [v for v in self.spark_submit_params] @@ -991,6 +1075,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'RepairRun': python_named_params=d.get('python_named_params', None), python_params=d.get('python_params', None), rerun_all_failed_tasks=d.get('rerun_all_failed_tasks', None), + rerun_dependent_tasks=d.get('rerun_dependent_tasks', None), rerun_tasks=d.get('rerun_tasks', None), run_id=d.get('run_id', None), spark_submit_params=d.get('spark_submit_params', None), @@ -1027,6 +1112,151 @@ def from_dict(cls, d: Dict[str, any]) -> 'ResetJob': return cls(job_id=d.get('job_id', None), new_settings=_from_dict(d, 'new_settings', JobSettings)) +@dataclass +class ResolvedConditionTaskValues: + left: Optional[str] = None + right: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.left is not None: body['left'] = self.left + if self.right is not None: body['right'] = self.right + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedConditionTaskValues': + return cls(left=d.get('left', None), right=d.get('right', None)) + + +@dataclass +class ResolvedDbtTaskValues: + commands: Optional['List[str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.commands: body['commands'] = [v for v in self.commands] + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedDbtTaskValues': + return cls(commands=d.get('commands', None)) + + +@dataclass +class ResolvedNotebookTaskValues: + base_parameters: Optional['Dict[str,str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.base_parameters: body['base_parameters'] = self.base_parameters + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedNotebookTaskValues': + return cls(base_parameters=d.get('base_parameters', None)) + + +@dataclass +class ResolvedParamPairValues: + parameters: Optional['Dict[str,str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.parameters: body['parameters'] = self.parameters + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedParamPairValues': + return cls(parameters=d.get('parameters', None)) + + +@dataclass +class ResolvedPythonWheelTaskValues: + named_parameters: Optional['Dict[str,str]'] = None + parameters: Optional['List[str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.named_parameters: body['named_parameters'] = self.named_parameters + if self.parameters: body['parameters'] = [v for v in self.parameters] + return body + + @classmethod + def from_dict(cls, d: Dict[str, 
any]) -> 'ResolvedPythonWheelTaskValues': + return cls(named_parameters=d.get('named_parameters', None), parameters=d.get('parameters', None)) + + +@dataclass +class ResolvedRunJobTaskValues: + named_parameters: Optional['Dict[str,str]'] = None + parameters: Optional['Dict[str,str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.named_parameters: body['named_parameters'] = self.named_parameters + if self.parameters: body['parameters'] = self.parameters + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedRunJobTaskValues': + return cls(named_parameters=d.get('named_parameters', None), parameters=d.get('parameters', None)) + + +@dataclass +class ResolvedStringParamsValues: + parameters: Optional['List[str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.parameters: body['parameters'] = [v for v in self.parameters] + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedStringParamsValues': + return cls(parameters=d.get('parameters', None)) + + +@dataclass +class ResolvedValues: + condition_task: Optional['ResolvedConditionTaskValues'] = None + dbt_task: Optional['ResolvedDbtTaskValues'] = None + notebook_task: Optional['ResolvedNotebookTaskValues'] = None + python_wheel_task: Optional['ResolvedPythonWheelTaskValues'] = None + run_job_task: Optional['ResolvedRunJobTaskValues'] = None + simulation_task: Optional['ResolvedParamPairValues'] = None + spark_jar_task: Optional['ResolvedStringParamsValues'] = None + spark_python_task: Optional['ResolvedStringParamsValues'] = None + spark_submit_task: Optional['ResolvedStringParamsValues'] = None + sql_task: Optional['ResolvedParamPairValues'] = None + + def as_dict(self) -> dict: + body = {} + if self.condition_task: body['condition_task'] = self.condition_task.as_dict() + if self.dbt_task: body['dbt_task'] = self.dbt_task.as_dict() + if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict() + if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict() + if self.run_job_task: body['run_job_task'] = self.run_job_task.as_dict() + if self.simulation_task: body['simulation_task'] = self.simulation_task.as_dict() + if self.spark_jar_task: body['spark_jar_task'] = self.spark_jar_task.as_dict() + if self.spark_python_task: body['spark_python_task'] = self.spark_python_task.as_dict() + if self.spark_submit_task: body['spark_submit_task'] = self.spark_submit_task.as_dict() + if self.sql_task: body['sql_task'] = self.sql_task.as_dict() + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ResolvedValues': + return cls(condition_task=_from_dict(d, 'condition_task', ResolvedConditionTaskValues), + dbt_task=_from_dict(d, 'dbt_task', ResolvedDbtTaskValues), + notebook_task=_from_dict(d, 'notebook_task', ResolvedNotebookTaskValues), + python_wheel_task=_from_dict(d, 'python_wheel_task', ResolvedPythonWheelTaskValues), + run_job_task=_from_dict(d, 'run_job_task', ResolvedRunJobTaskValues), + simulation_task=_from_dict(d, 'simulation_task', ResolvedParamPairValues), + spark_jar_task=_from_dict(d, 'spark_jar_task', ResolvedStringParamsValues), + spark_python_task=_from_dict(d, 'spark_python_task', ResolvedStringParamsValues), + spark_submit_task=_from_dict(d, 'spark_submit_task', ResolvedStringParamsValues), + sql_task=_from_dict(d, 'sql_task', ResolvedParamPairValues)) + + @dataclass class Run: attempt_number: Optional[int] = None @@ -1040,6 +1270,7 @@ class Run: git_source: Optional['GitSource'] = None 
job_clusters: Optional['List[JobCluster]'] = None job_id: Optional[int] = None + job_parameters: Optional['List[JobParameter]'] = None number_in_job: Optional[int] = None original_attempt_run_id: Optional[int] = None overriding_parameters: Optional['RunParameters'] = None @@ -1055,6 +1286,7 @@ class Run: state: Optional['RunState'] = None tasks: Optional['List[RunTask]'] = None trigger: Optional['TriggerType'] = None + trigger_info: Optional['TriggerInfo'] = None def as_dict(self) -> dict: body = {} @@ -1069,6 +1301,7 @@ def as_dict(self) -> dict: if self.git_source: body['git_source'] = self.git_source.as_dict() if self.job_clusters: body['job_clusters'] = [v.as_dict() for v in self.job_clusters] if self.job_id is not None: body['job_id'] = self.job_id + if self.job_parameters: body['job_parameters'] = [v.as_dict() for v in self.job_parameters] if self.number_in_job is not None: body['number_in_job'] = self.number_in_job if self.original_attempt_run_id is not None: body['original_attempt_run_id'] = self.original_attempt_run_id @@ -1085,6 +1318,7 @@ def as_dict(self) -> dict: if self.state: body['state'] = self.state.as_dict() if self.tasks: body['tasks'] = [v.as_dict() for v in self.tasks] if self.trigger is not None: body['trigger'] = self.trigger.value + if self.trigger_info: body['trigger_info'] = self.trigger_info.as_dict() return body @classmethod @@ -1100,6 +1334,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'Run': git_source=_from_dict(d, 'git_source', GitSource), job_clusters=_repeated(d, 'job_clusters', JobCluster), job_id=d.get('job_id', None), + job_parameters=_repeated(d, 'job_parameters', JobParameter), number_in_job=d.get('number_in_job', None), original_attempt_run_id=d.get('original_attempt_run_id', None), overriding_parameters=_from_dict(d, 'overriding_parameters', RunParameters), @@ -1114,7 +1349,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'Run': start_time=d.get('start_time', None), state=_from_dict(d, 'state', RunState), tasks=_repeated(d, 'tasks', RunTask), - trigger=_enum(d, 'trigger', TriggerType)) + trigger=_enum(d, 'trigger', TriggerType), + trigger_info=_from_dict(d, 'trigger_info', TriggerInfo)) @dataclass @@ -1151,6 +1387,47 @@ class RunConditionTaskOp(Enum): NOT_EQUAL = 'NOT_EQUAL' +class RunIf(Enum): + """This describes an enum""" + + ALL_DONE = 'ALL_DONE' + ALL_FAILED = 'ALL_FAILED' + ALL_SUCCESS = 'ALL_SUCCESS' + AT_LEAST_ONE_FAILED = 'AT_LEAST_ONE_FAILED' + AT_LEAST_ONE_SUCCESS = 'AT_LEAST_ONE_SUCCESS' + NONE_FAILED = 'NONE_FAILED' + + +@dataclass +class RunJobOutput: + run_id: Optional[int] = None + + def as_dict(self) -> dict: + body = {} + if self.run_id is not None: body['run_id'] = self.run_id + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'RunJobOutput': + return cls(run_id=d.get('run_id', None)) + + +@dataclass +class RunJobTask: + job_id: int + job_parameters: Optional[Any] = None + + def as_dict(self) -> dict: + body = {} + if self.job_id is not None: body['job_id'] = self.job_id + if self.job_parameters: body['job_parameters'] = self.job_parameters + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'RunJobTask': + return cls(job_id=d.get('job_id', None), job_parameters=d.get('job_parameters', None)) + + class RunLifeCycleState(Enum): """This describes an enum""" @@ -1170,6 +1447,7 @@ class RunNow: dbt_commands: Optional['List[str]'] = None idempotency_token: Optional[str] = None jar_params: Optional['List[str]'] = None + job_parameters: Optional['List[Dict[str,str]]'] = None notebook_params: 
Optional['Dict[str,str]'] = None pipeline_params: Optional['PipelineParams'] = None python_named_params: Optional['Dict[str,str]'] = None @@ -1183,6 +1461,7 @@ def as_dict(self) -> dict: if self.idempotency_token is not None: body['idempotency_token'] = self.idempotency_token if self.jar_params: body['jar_params'] = [v for v in self.jar_params] if self.job_id is not None: body['job_id'] = self.job_id + if self.job_parameters: body['job_parameters'] = [v for v in self.job_parameters] if self.notebook_params: body['notebook_params'] = self.notebook_params if self.pipeline_params: body['pipeline_params'] = self.pipeline_params.as_dict() if self.python_named_params: body['python_named_params'] = self.python_named_params @@ -1197,6 +1476,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'RunNow': idempotency_token=d.get('idempotency_token', None), jar_params=d.get('jar_params', None), job_id=d.get('job_id', None), + job_parameters=d.get('job_parameters', None), notebook_params=d.get('notebook_params', None), pipeline_params=_from_dict(d, 'pipeline_params', PipelineParams), python_named_params=d.get('python_named_params', None), @@ -1231,6 +1511,7 @@ class RunOutput: logs_truncated: Optional[bool] = None metadata: Optional['Run'] = None notebook_output: Optional['NotebookOutput'] = None + run_job_output: Optional['RunJobOutput'] = None sql_output: Optional['SqlOutput'] = None def as_dict(self) -> dict: @@ -1243,6 +1524,7 @@ def as_dict(self) -> dict: if self.logs_truncated is not None: body['logs_truncated'] = self.logs_truncated if self.metadata: body['metadata'] = self.metadata.as_dict() if self.notebook_output: body['notebook_output'] = self.notebook_output.as_dict() + if self.run_job_output: body['run_job_output'] = self.run_job_output.as_dict() if self.sql_output: body['sql_output'] = self.sql_output.as_dict() return body @@ -1256,6 +1538,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'RunOutput': logs_truncated=d.get('logs_truncated', None), metadata=_from_dict(d, 'metadata', Run), notebook_output=_from_dict(d, 'notebook_output', NotebookOutput), + run_job_output=_from_dict(d, 'run_job_output', RunJobOutput), sql_output=_from_dict(d, 'sql_output', SqlOutput)) @@ -1298,14 +1581,19 @@ class RunResultState(Enum): """This describes an enum""" CANCELED = 'CANCELED' + EXCLUDED = 'EXCLUDED' FAILED = 'FAILED' + MAXIMUM_CONCURRENT_RUNS_REACHED = 'MAXIMUM_CONCURRENT_RUNS_REACHED' SUCCESS = 'SUCCESS' + SUCCESS_WITH_FAILURES = 'SUCCESS_WITH_FAILURES' TIMEDOUT = 'TIMEDOUT' + UPSTREAM_CANCELED = 'UPSTREAM_CANCELED' + UPSTREAM_FAILED = 'UPSTREAM_FAILED' @dataclass class RunState: - """The result and lifecycle state of the run.""" + """The current state of the run.""" life_cycle_state: Optional['RunLifeCycleState'] = None result_state: Optional['RunResultState'] = None @@ -1347,7 +1635,10 @@ class RunTask: notebook_task: Optional['NotebookTask'] = None pipeline_task: Optional['PipelineTask'] = None python_wheel_task: Optional['PythonWheelTask'] = None + resolved_values: Optional['ResolvedValues'] = None run_id: Optional[int] = None + run_if: Optional['RunIf'] = None + run_job_task: Optional['RunJobTask'] = None setup_duration: Optional[int] = None spark_jar_task: Optional['SparkJarTask'] = None spark_python_task: Optional['SparkPythonTask'] = None @@ -1375,7 +1666,10 @@ def as_dict(self) -> dict: if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict() if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict() if self.python_wheel_task: body['python_wheel_task'] = 
self.python_wheel_task.as_dict() + if self.resolved_values: body['resolved_values'] = self.resolved_values.as_dict() if self.run_id is not None: body['run_id'] = self.run_id + if self.run_if is not None: body['run_if'] = self.run_if.value + if self.run_job_task: body['run_job_task'] = self.run_job_task.as_dict() if self.setup_duration is not None: body['setup_duration'] = self.setup_duration if self.spark_jar_task: body['spark_jar_task'] = self.spark_jar_task.as_dict() if self.spark_python_task: body['spark_python_task'] = self.spark_python_task.as_dict() @@ -1404,7 +1698,10 @@ def from_dict(cls, d: Dict[str, any]) -> 'RunTask': notebook_task=_from_dict(d, 'notebook_task', NotebookTask), pipeline_task=_from_dict(d, 'pipeline_task', PipelineTask), python_wheel_task=_from_dict(d, 'python_wheel_task', PythonWheelTask), + resolved_values=_from_dict(d, 'resolved_values', ResolvedValues), run_id=d.get('run_id', None), + run_if=_enum(d, 'run_if', RunIf), + run_job_task=_from_dict(d, 'run_job_task', RunJobTask), setup_duration=d.get('setup_duration', None), spark_jar_task=_from_dict(d, 'spark_jar_task', SparkJarTask), spark_python_task=_from_dict(d, 'spark_python_task', SparkPythonTask), @@ -1769,6 +2066,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'SqlTaskSubscription': @dataclass class SubmitRun: access_control_list: Optional['List[iam.AccessControlRequest]'] = None + email_notifications: Optional['JobEmailNotifications'] = None git_source: Optional['GitSource'] = None idempotency_token: Optional[str] = None notification_settings: Optional['JobNotificationSettings'] = None @@ -1781,6 +2079,7 @@ def as_dict(self) -> dict: body = {} if self.access_control_list: body['access_control_list'] = [v.as_dict() for v in self.access_control_list] + if self.email_notifications: body['email_notifications'] = self.email_notifications.as_dict() if self.git_source: body['git_source'] = self.git_source.as_dict() if self.idempotency_token is not None: body['idempotency_token'] = self.idempotency_token if self.notification_settings: body['notification_settings'] = self.notification_settings.as_dict() @@ -1793,6 +2092,7 @@ def as_dict(self) -> dict: @classmethod def from_dict(cls, d: Dict[str, any]) -> 'SubmitRun': return cls(access_control_list=_repeated(d, 'access_control_list', iam.AccessControlRequest), + email_notifications=_from_dict(d, 'email_notifications', JobEmailNotifications), git_source=_from_dict(d, 'git_source', GitSource), idempotency_token=d.get('idempotency_token', None), notification_settings=_from_dict(d, 'notification_settings', JobNotificationSettings), @@ -1821,10 +2121,12 @@ class SubmitTask: task_key: str condition_task: Optional['ConditionTask'] = None depends_on: Optional['List[TaskDependency]'] = None + email_notifications: Optional['JobEmailNotifications'] = None existing_cluster_id: Optional[str] = None libraries: Optional['List[compute.Library]'] = None new_cluster: Optional['compute.ClusterSpec'] = None notebook_task: Optional['NotebookTask'] = None + notification_settings: Optional['TaskNotificationSettings'] = None pipeline_task: Optional['PipelineTask'] = None python_wheel_task: Optional['PythonWheelTask'] = None spark_jar_task: Optional['SparkJarTask'] = None @@ -1837,10 +2139,12 @@ def as_dict(self) -> dict: body = {} if self.condition_task: body['condition_task'] = self.condition_task.as_dict() if self.depends_on: body['depends_on'] = [v.as_dict() for v in self.depends_on] + if self.email_notifications: body['email_notifications'] = self.email_notifications.as_dict() if 
self.existing_cluster_id is not None: body['existing_cluster_id'] = self.existing_cluster_id if self.libraries: body['libraries'] = [v.as_dict() for v in self.libraries] if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict() if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict() + if self.notification_settings: body['notification_settings'] = self.notification_settings.as_dict() if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict() if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict() if self.spark_jar_task: body['spark_jar_task'] = self.spark_jar_task.as_dict() @@ -1855,10 +2159,12 @@ def as_dict(self) -> dict: def from_dict(cls, d: Dict[str, any]) -> 'SubmitTask': return cls(condition_task=_from_dict(d, 'condition_task', ConditionTask), depends_on=_repeated(d, 'depends_on', TaskDependency), + email_notifications=_from_dict(d, 'email_notifications', JobEmailNotifications), existing_cluster_id=d.get('existing_cluster_id', None), libraries=_repeated(d, 'libraries', compute.Library), new_cluster=_from_dict(d, 'new_cluster', compute.ClusterSpec), notebook_task=_from_dict(d, 'notebook_task', NotebookTask), + notification_settings=_from_dict(d, 'notification_settings', TaskNotificationSettings), pipeline_task=_from_dict(d, 'pipeline_task', PipelineTask), python_wheel_task=_from_dict(d, 'python_wheel_task', PythonWheelTask), spark_jar_task=_from_dict(d, 'spark_jar_task', SparkJarTask), @@ -1889,6 +2195,8 @@ class Task: pipeline_task: Optional['PipelineTask'] = None python_wheel_task: Optional['PythonWheelTask'] = None retry_on_timeout: Optional[bool] = None + run_if: Optional['RunIf'] = None + run_job_task: Optional['RunJobTask'] = None spark_jar_task: Optional['SparkJarTask'] = None spark_python_task: Optional['SparkPythonTask'] = None spark_submit_task: Optional['SparkSubmitTask'] = None @@ -1915,6 +2223,8 @@ def as_dict(self) -> dict: if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict() if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict() if self.retry_on_timeout is not None: body['retry_on_timeout'] = self.retry_on_timeout + if self.run_if is not None: body['run_if'] = self.run_if.value + if self.run_job_task: body['run_job_task'] = self.run_job_task.as_dict() if self.spark_jar_task: body['spark_jar_task'] = self.spark_jar_task.as_dict() if self.spark_python_task: body['spark_python_task'] = self.spark_python_task.as_dict() if self.spark_submit_task: body['spark_submit_task'] = self.spark_submit_task.as_dict() @@ -1942,6 +2252,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'Task': pipeline_task=_from_dict(d, 'pipeline_task', PipelineTask), python_wheel_task=_from_dict(d, 'python_wheel_task', PythonWheelTask), retry_on_timeout=d.get('retry_on_timeout', None), + run_if=_enum(d, 'run_if', RunIf), + run_job_task=_from_dict(d, 'run_job_task', RunJobTask), spark_jar_task=_from_dict(d, 'spark_jar_task', SparkJarTask), spark_python_task=_from_dict(d, 'spark_python_task', SparkPythonTask), spark_submit_task=_from_dict(d, 'spark_submit_task', SparkSubmitTask), @@ -2048,6 +2360,20 @@ def from_dict(cls, d: Dict[str, any]) -> 'TriggerHistory': last_triggered=_from_dict(d, 'last_triggered', TriggerEvaluation)) +@dataclass +class TriggerInfo: + run_id: Optional[int] = None + + def as_dict(self) -> dict: + body = {} + if self.run_id is not None: body['run_id'] = self.run_id + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'TriggerInfo': 
+ return cls(run_id=d.get('run_id', None)) + + @dataclass class TriggerSettings: file_arrival: Optional['FileArrivalTriggerConfiguration'] = None @@ -2072,6 +2398,7 @@ class TriggerType(Enum): ONE_TIME = 'ONE_TIME' PERIODIC = 'PERIODIC' RETRY = 'RETRY' + RUN_JOB_TASK = 'RUN_JOB_TASK' @dataclass @@ -2265,6 +2592,7 @@ def create(self, max_concurrent_runs: Optional[int] = None, name: Optional[str] = None, notification_settings: Optional[JobNotificationSettings] = None, + parameters: Optional[List[JobParameterDefinition]] = None, run_as: Optional[JobRunAs] = None, schedule: Optional[CronSchedule] = None, tags: Optional[Dict[str, str]] = None, @@ -2315,6 +2643,8 @@ def create(self, :param notification_settings: :class:`JobNotificationSettings` (optional) Optional notification settings that are used when sending notifications to each of the `email_notifications` and `webhook_notifications` for this job. + :param parameters: List[:class:`JobParameterDefinition`] (optional) + Job-level parameter definitions :param run_as: :class:`JobRunAs` (optional) Write-only setting, available only in Create/Update/Reset and Submit calls. Specifies the user or service principal that the job runs as. If not specified, the job runs as the user who created the @@ -2355,6 +2685,7 @@ def create(self, max_concurrent_runs=max_concurrent_runs, name=name, notification_settings=notification_settings, + parameters=parameters, run_as=run_as, schedule=schedule, tags=tags, @@ -2511,8 +2842,8 @@ def list(self, :param expand_tasks: bool (optional) Whether to include task and cluster details in the response. :param limit: int (optional) - The number of jobs to return. This value must be greater than 0 and less or equal to 25. The default - value is 20. + The number of jobs to return. This value must be greater than 0 and less than or equal to 100. The + default value is 20. :param name: str (optional) A filter on the list based on the exact (case insensitive) job name. :param offset: int (optional) @@ -2645,6 +2976,7 @@ def repair_run(self, python_named_params: Optional[Dict[str, str]] = None, python_params: Optional[List[str]] = None, rerun_all_failed_tasks: Optional[bool] = None, + rerun_dependent_tasks: Optional[bool] = None, rerun_tasks: Optional[List[str]] = None, spark_submit_params: Optional[List[str]] = None, sql_params: Optional[Dict[str, str]] = None, @@ -2707,7 +3039,10 @@ def repair_run(self, [Task parameter variables]: https://docs.databricks.com/jobs.html#parameter-variables :param rerun_all_failed_tasks: bool (optional) - If true, repair all failed tasks. Only one of rerun_tasks or rerun_all_failed_tasks can be used. + If true, repair all failed tasks. Only one of `rerun_tasks` or `rerun_all_failed_tasks` can be used. + :param rerun_dependent_tasks: bool (optional) + If true, repair all tasks that depend on the tasks in `rerun_tasks`, even if they were previously + successful. Can also be used in combination with `rerun_all_failed_tasks`. :param rerun_tasks: List[str] (optional) The task keys of the task runs to repair. 
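+            For example, a hedged sketch (illustrative only; the run id and task keys are hypothetical,
+            and a configured ``WorkspaceClient`` ``w`` is assumed) that repairs two failed tasks together
+            with everything downstream of them::
+
+                w.jobs.repair_run(run_id=1234,
+                                  rerun_tasks=['ingest', 'transform'],
+                                  rerun_dependent_tasks=True)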
:param spark_submit_params: List[str] (optional) @@ -2744,6 +3079,7 @@ def repair_run(self, python_named_params=python_named_params, python_params=python_params, rerun_all_failed_tasks=rerun_all_failed_tasks, + rerun_dependent_tasks=rerun_dependent_tasks, rerun_tasks=rerun_tasks, run_id=run_id, spark_submit_params=spark_submit_params, @@ -2766,6 +3102,7 @@ def repair_run_and_wait( python_named_params: Optional[Dict[str, str]] = None, python_params: Optional[List[str]] = None, rerun_all_failed_tasks: Optional[bool] = None, + rerun_dependent_tasks: Optional[bool] = None, rerun_tasks: Optional[List[str]] = None, spark_submit_params: Optional[List[str]] = None, sql_params: Optional[Dict[str, str]] = None, @@ -2778,6 +3115,7 @@ def repair_run_and_wait( python_named_params=python_named_params, python_params=python_params, rerun_all_failed_tasks=rerun_all_failed_tasks, + rerun_dependent_tasks=rerun_dependent_tasks, rerun_tasks=rerun_tasks, run_id=run_id, spark_submit_params=spark_submit_params, @@ -2811,6 +3149,7 @@ def run_now(self, dbt_commands: Optional[List[str]] = None, idempotency_token: Optional[str] = None, jar_params: Optional[List[str]] = None, + job_parameters: Optional[List[Dict[str, str]]] = None, notebook_params: Optional[Dict[str, str]] = None, pipeline_params: Optional[PipelineParams] = None, python_named_params: Optional[Dict[str, str]] = None, @@ -2849,6 +3188,8 @@ def run_now(self, Use [Task parameter variables](/jobs.html"#parameter-variables") to set parameters containing information about job runs. + :param job_parameters: List[Dict[str,str]] (optional) + Job-level parameters used in the run :param notebook_params: Dict[str,str] (optional) A map from keys to values for jobs with notebook task, for example `\"notebook_params\": {\"name\": \"john doe\", \"age\": \"35\"}`. The map is passed to the notebook and is accessible through the @@ -2914,6 +3255,7 @@ def run_now(self, idempotency_token=idempotency_token, jar_params=jar_params, job_id=job_id, + job_parameters=job_parameters, notebook_params=notebook_params, pipeline_params=pipeline_params, python_named_params=python_named_params, @@ -2932,6 +3274,7 @@ def run_now_and_wait(self, dbt_commands: Optional[List[str]] = None, idempotency_token: Optional[str] = None, jar_params: Optional[List[str]] = None, + job_parameters: Optional[List[Dict[str, str]]] = None, notebook_params: Optional[Dict[str, str]] = None, pipeline_params: Optional[PipelineParams] = None, python_named_params: Optional[Dict[str, str]] = None, @@ -2943,6 +3286,7 @@ def run_now_and_wait(self, idempotency_token=idempotency_token, jar_params=jar_params, job_id=job_id, + job_parameters=job_parameters, notebook_params=notebook_params, pipeline_params=pipeline_params, python_named_params=python_named_params, @@ -2953,6 +3297,7 @@ def run_now_and_wait(self, def submit(self, *, access_control_list: Optional[List[iam.AccessControlRequest]] = None, + email_notifications: Optional[JobEmailNotifications] = None, git_source: Optional[GitSource] = None, idempotency_token: Optional[str] = None, notification_settings: Optional[JobNotificationSettings] = None, @@ -2969,6 +3314,9 @@ def submit(self, :param access_control_list: List[:class:`AccessControlRequest`] (optional) List of permissions to set on the job. + :param email_notifications: :class:`JobEmailNotifications` (optional) + An optional set of email addresses notified when the run begins or completes. The default behavior + is to not send any emails. 
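+          For example, a hedged sketch (illustrative only; the task key, notebook path, cluster id and
+          address are hypothetical, and a configured ``WorkspaceClient`` ``w`` is assumed) of a one-off
+          run that emails on failure::
+
+              w.jobs.submit(run_name='one-off',
+                            email_notifications=JobEmailNotifications(on_failure=['ops@example.com']),
+                            tasks=[SubmitTask(task_key='main',
+                                              existing_cluster_id='0123-456789-abcdefgh',
+                                              notebook_task=NotebookTask(notebook_path='/Shared/example'))])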
:param git_source: :class:`GitSource` (optional) An optional specification for a remote repository containing the notebooks used by this job's notebook tasks. @@ -3004,6 +3352,7 @@ def submit(self, request = kwargs.get('request', None) if not request: # request is not given through keyed args request = SubmitRun(access_control_list=access_control_list, + email_notifications=email_notifications, git_source=git_source, idempotency_token=idempotency_token, notification_settings=notification_settings, @@ -3021,6 +3370,7 @@ def submit_and_wait( self, *, access_control_list: Optional[List[iam.AccessControlRequest]] = None, + email_notifications: Optional[JobEmailNotifications] = None, git_source: Optional[GitSource] = None, idempotency_token: Optional[str] = None, notification_settings: Optional[JobNotificationSettings] = None, @@ -3030,6 +3380,7 @@ def submit_and_wait( webhook_notifications: Optional[WebhookNotifications] = None, timeout=timedelta(minutes=20)) -> Run: return self.submit(access_control_list=access_control_list, + email_notifications=email_notifications, git_source=git_source, idempotency_token=idempotency_token, notification_settings=notification_settings, diff --git a/databricks/sdk/service/ml.py b/databricks/sdk/service/ml.py index 8f9fa984e..45b8b4c25 100755 --- a/databricks/sdk/service/ml.py +++ b/databricks/sdk/service/ml.py @@ -554,10 +554,10 @@ class DeleteTransitionRequestRequest: class DeleteTransitionRequestStage(Enum): - Archived = 'Archived' - None_ = 'None' - Production = 'Production' - Staging = 'Staging' + ARCHIVED = 'Archived' + NONE = 'None' + PRODUCTION = 'Production' + STAGING = 'Staging' @dataclass @@ -1924,10 +1924,10 @@ def from_dict(cls, d: Dict[str, any]) -> 'SetTag': class Stage(Enum): """This describes an enum""" - Archived = 'Archived' - None_ = 'None' - Production = 'Production' - Staging = 'Staging' + ARCHIVED = 'Archived' + NONE = 'None' + PRODUCTION = 'Production' + STAGING = 'Staging' class Status(Enum): diff --git a/databricks/sdk/service/provisioning.py b/databricks/sdk/service/provisioning.py index 2c1bb99ba..b911e4500 100755 --- a/databricks/sdk/service/provisioning.py +++ b/databricks/sdk/service/provisioning.py @@ -431,11 +431,11 @@ class ErrorType(Enum): """The AWS resource associated with this error: credentials, VPC, subnet, security group, or network ACL.""" - credentials = 'credentials' - networkAcl = 'networkAcl' - securityGroup = 'securityGroup' - subnet = 'subnet' - vpc = 'vpc' + CREDENTIALS = 'credentials' + NETWORK_ACL = 'networkAcl' + SECURITY_GROUP = 'securityGroup' + SUBNET = 'subnet' + VPC = 'vpc' @dataclass @@ -988,8 +988,8 @@ class VpcStatus(Enum): class WarningType(Enum): """The AWS resource associated with this warning: a subnet or a security group.""" - securityGroup = 'securityGroup' - subnet = 'subnet' + SECURITY_GROUP = 'securityGroup' + SUBNET = 'subnet' @dataclass diff --git a/databricks/sdk/service/serving.py b/databricks/sdk/service/serving.py index 960d68ee4..690e66c41 100755 --- a/databricks/sdk/service/serving.py +++ b/databricks/sdk/service/serving.py @@ -256,7 +256,7 @@ class ServedModelInput: model_version: str workload_size: str scale_to_zero_enabled: bool - environment_vars: Optional[Any] = None + environment_vars: Optional['Dict[str,str]'] = None name: Optional[str] = None def as_dict(self) -> dict: @@ -283,7 +283,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'ServedModelInput': class ServedModelOutput: creation_timestamp: Optional[int] = None creator: Optional[str] = None - environment_vars: Optional[Any] 
= None + environment_vars: Optional['Dict[str,str]'] = None model_name: Optional[str] = None model_version: Optional[str] = None name: Optional[str] = None @@ -484,13 +484,13 @@ def from_dict(cls, d: Dict[str, any]) -> 'TrafficConfig': class ServingEndpointsAPI: """The Serving Endpoints API allows you to create, update, and delete model serving endpoints. - You can use a serving endpoint to serve models from the Databricks Model Registry. Endpoints expose the - underlying models as scalable REST API endpoints using serverless compute. This means the endpoints and - associated compute resources are fully managed by Databricks and will not appear in your cloud account. A - serving endpoint can consist of one or more MLflow models from the Databricks Model Registry, called - served models. A serving endpoint can have at most ten served models. You can configure traffic settings - to define how requests should be routed to your served models behind an endpoint. Additionally, you can - configure the scale of resources that should be applied to each served model.""" + You can use a serving endpoint to serve models from the Databricks Model Registry or from Unity Catalog. + Endpoints expose the underlying models as scalable REST API endpoints using serverless compute. This means + the endpoints and associated compute resources are fully managed by Databricks and will not appear in your + cloud account. A serving endpoint can consist of one or more MLflow models from the Databricks Model + Registry, called served models. A serving endpoint can have at most ten served models. You can configure + traffic settings to define how requests should be routed to your served models behind an endpoint. + Additionally, you can configure the scale of resources that should be applied to each served model.""" def __init__(self, api_client): self._api = api_client diff --git a/databricks/sdk/service/settings.py b/databricks/sdk/service/settings.py index c9e48e8fb..9cc0aa7af 100755 --- a/databricks/sdk/service/settings.py +++ b/databricks/sdk/service/settings.py @@ -133,7 +133,7 @@ class DeleteIpAccessListRequest: class DeletePersonalComputeSettingRequest: """Delete Personal Compute setting""" - etag: Optional[str] = None + etag: str @dataclass @@ -312,7 +312,11 @@ def from_dict(cls, d: Dict[str, any]) -> 'PersonalComputeMessage': class PersonalComputeMessageEnum(Enum): - """TBD""" + """ON: Grants all users in all workspaces access to the Personal Compute default policy, allowing + all users to create single-machine compute resources. DELEGATE: Moves access control for the + Personal Compute default policy to individual workspaces and requires a workspace’s users or + groups to be added to the ACLs of that workspace’s Personal Compute default policy before they + will be able to create compute resources through that policy.""" DELEGATE = 'DELEGATE' ON = 'ON' @@ -365,7 +369,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'PublicTokenInfo': class ReadPersonalComputeSettingRequest: """Get Personal Compute setting""" - etag: Optional[str] = None + etag: str @dataclass @@ -708,21 +712,28 @@ def update(self, class AccountSettingsAPI: - """TBD""" + """The Personal Compute enablement setting lets you control which users can use the Personal Compute default + policy to create compute resources. By default all users in all workspaces have access (ON), but you can + change the setting to instead let individual workspaces configure access control (DELEGATE). + + There is only one instance of this setting per account. 
Since this setting has a default value, this + setting is present on all accounts even though it's never set on a given account. Deletion reverts the + value of the setting back to the default value.""" def __init__(self, api_client): self._api = api_client - def delete_personal_compute_setting(self, - *, - etag: Optional[str] = None, - **kwargs) -> DeletePersonalComputeSettingResponse: + def delete_personal_compute_setting(self, etag: str, **kwargs) -> DeletePersonalComputeSettingResponse: """Delete Personal Compute setting. - TBD + Reverts back the Personal Compute setting value to default (ON) - :param etag: str (optional) - TBD + :param etag: str + etag used for versioning. The response is at least as fresh as the eTag provided. This is used for + optimistic concurrency control as a way to help prevent simultaneous writes of a setting overwriting + each other. It is strongly suggested that systems make use of the etag in the read -> delete pattern + to perform setting deletions in order to avoid race conditions. That is, get an etag from a GET + request, and pass it with the DELETE request to identify the rule set version you are deleting. :returns: :class:`DeletePersonalComputeSettingResponse` """ @@ -739,16 +750,17 @@ def delete_personal_compute_setting(self, query=query) return DeletePersonalComputeSettingResponse.from_dict(json) - def read_personal_compute_setting(self, - *, - etag: Optional[str] = None, - **kwargs) -> PersonalComputeSetting: + def read_personal_compute_setting(self, etag: str, **kwargs) -> PersonalComputeSetting: """Get Personal Compute setting. - TBD + Gets the value of the Personal Compute setting. - :param etag: str (optional) - TBD + :param etag: str + etag used for versioning. The response is at least as fresh as the eTag provided. This is used for + optimistic concurrency control as a way to help prevent simultaneous writes of a setting overwriting + each other. It is strongly suggested that systems make use of the etag in the read -> delete pattern + to perform setting deletions in order to avoid race conditions. That is, get an etag from a GET + request, and pass it with the DELETE request to identify the rule set version you are deleting. :returns: :class:`PersonalComputeSetting` """ @@ -772,10 +784,10 @@ def update_personal_compute_setting(self, **kwargs) -> PersonalComputeSetting: """Update Personal Compute setting. - TBD + Updates the value of the Personal Compute setting. :param allow_missing: bool (optional) - TBD + This should always be set to true for Settings RPCs. Added for AIP compliance. 
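+          For example, a hedged sketch (illustrative only; it assumes the account client exposes this
+          service as ``a.settings`` and that the dataclass shapes shown above are used as-is)::
+
+              a.settings.update_personal_compute_setting(
+                  allow_missing=True,
+                  setting=PersonalComputeSetting(
+                      personal_compute=PersonalComputeMessage(value=PersonalComputeMessageEnum.DELEGATE)))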
:param setting: :class:`PersonalComputeSetting` (optional) :returns: :class:`PersonalComputeSetting` diff --git a/databricks/sdk/service/sharing.py b/databricks/sdk/service/sharing.py index 781b77240..b9f530af3 100755 --- a/databricks/sdk/service/sharing.py +++ b/databricks/sdk/service/sharing.py @@ -21,6 +21,303 @@ class AuthenticationType(Enum): TOKEN = 'TOKEN' +@dataclass +class CentralCleanRoomInfo: + clean_room_assets: Optional['List[CleanRoomAssetInfo]'] = None + collaborators: Optional['List[CleanRoomCollaboratorInfo]'] = None + creator: Optional['CleanRoomCollaboratorInfo'] = None + station_cloud: Optional[str] = None + station_region: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.clean_room_assets: body['clean_room_assets'] = [v.as_dict() for v in self.clean_room_assets] + if self.collaborators: body['collaborators'] = [v.as_dict() for v in self.collaborators] + if self.creator: body['creator'] = self.creator.as_dict() + if self.station_cloud is not None: body['station_cloud'] = self.station_cloud + if self.station_region is not None: body['station_region'] = self.station_region + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CentralCleanRoomInfo': + return cls(clean_room_assets=_repeated(d, 'clean_room_assets', CleanRoomAssetInfo), + collaborators=_repeated(d, 'collaborators', CleanRoomCollaboratorInfo), + creator=_from_dict(d, 'creator', CleanRoomCollaboratorInfo), + station_cloud=d.get('station_cloud', None), + station_region=d.get('station_region', None)) + + +@dataclass +class CleanRoomAssetInfo: + added_at: Optional[int] = None + notebook_info: Optional['CleanRoomNotebookInfo'] = None + owner: Optional['CleanRoomCollaboratorInfo'] = None + table_info: Optional['CleanRoomTableInfo'] = None + updated_at: Optional[int] = None + + def as_dict(self) -> dict: + body = {} + if self.added_at is not None: body['added_at'] = self.added_at + if self.notebook_info: body['notebook_info'] = self.notebook_info.as_dict() + if self.owner: body['owner'] = self.owner.as_dict() + if self.table_info: body['table_info'] = self.table_info.as_dict() + if self.updated_at is not None: body['updated_at'] = self.updated_at + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomAssetInfo': + return cls(added_at=d.get('added_at', None), + notebook_info=_from_dict(d, 'notebook_info', CleanRoomNotebookInfo), + owner=_from_dict(d, 'owner', CleanRoomCollaboratorInfo), + table_info=_from_dict(d, 'table_info', CleanRoomTableInfo), + updated_at=d.get('updated_at', None)) + + +@dataclass +class CleanRoomCatalog: + catalog_name: Optional[str] = None + notebook_files: Optional['List[SharedDataObject]'] = None + tables: Optional['List[SharedDataObject]'] = None + + def as_dict(self) -> dict: + body = {} + if self.catalog_name is not None: body['catalog_name'] = self.catalog_name + if self.notebook_files: body['notebook_files'] = [v.as_dict() for v in self.notebook_files] + if self.tables: body['tables'] = [v.as_dict() for v in self.tables] + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomCatalog': + return cls(catalog_name=d.get('catalog_name', None), + notebook_files=_repeated(d, 'notebook_files', SharedDataObject), + tables=_repeated(d, 'tables', SharedDataObject)) + + +@dataclass +class CleanRoomCatalogUpdate: + catalog_name: Optional[str] = None + updates: Optional['SharedDataObjectUpdate'] = None + + def as_dict(self) -> dict: + body = {} + if self.catalog_name is not None: 
body['catalog_name'] = self.catalog_name + if self.updates: body['updates'] = self.updates.as_dict() + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomCatalogUpdate': + return cls(catalog_name=d.get('catalog_name', None), + updates=_from_dict(d, 'updates', SharedDataObjectUpdate)) + + +@dataclass +class CleanRoomCollaboratorInfo: + global_metastore_id: Optional[str] = None + organization_name: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.global_metastore_id is not None: body['global_metastore_id'] = self.global_metastore_id + if self.organization_name is not None: body['organization_name'] = self.organization_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomCollaboratorInfo': + return cls(global_metastore_id=d.get('global_metastore_id', None), + organization_name=d.get('organization_name', None)) + + +@dataclass +class CleanRoomInfo: + comment: Optional[str] = None + created_at: Optional[int] = None + created_by: Optional[str] = None + local_catalogs: Optional['List[CleanRoomCatalog]'] = None + name: Optional[str] = None + owner: Optional[str] = None + remote_detailed_info: Optional['CentralCleanRoomInfo'] = None + updated_at: Optional[int] = None + updated_by: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.comment is not None: body['comment'] = self.comment + if self.created_at is not None: body['created_at'] = self.created_at + if self.created_by is not None: body['created_by'] = self.created_by + if self.local_catalogs: body['local_catalogs'] = [v.as_dict() for v in self.local_catalogs] + if self.name is not None: body['name'] = self.name + if self.owner is not None: body['owner'] = self.owner + if self.remote_detailed_info: body['remote_detailed_info'] = self.remote_detailed_info.as_dict() + if self.updated_at is not None: body['updated_at'] = self.updated_at + if self.updated_by is not None: body['updated_by'] = self.updated_by + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomInfo': + return cls(comment=d.get('comment', None), + created_at=d.get('created_at', None), + created_by=d.get('created_by', None), + local_catalogs=_repeated(d, 'local_catalogs', CleanRoomCatalog), + name=d.get('name', None), + owner=d.get('owner', None), + remote_detailed_info=_from_dict(d, 'remote_detailed_info', CentralCleanRoomInfo), + updated_at=d.get('updated_at', None), + updated_by=d.get('updated_by', None)) + + +@dataclass +class CleanRoomNotebookInfo: + notebook_content: Optional[str] = None + notebook_name: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.notebook_content is not None: body['notebook_content'] = self.notebook_content + if self.notebook_name is not None: body['notebook_name'] = self.notebook_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomNotebookInfo': + return cls(notebook_content=d.get('notebook_content', None), + notebook_name=d.get('notebook_name', None)) + + +@dataclass +class CleanRoomTableInfo: + catalog_name: Optional[str] = None + columns: Optional['List[ColumnInfo]'] = None + full_name: Optional[str] = None + name: Optional[str] = None + schema_name: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.catalog_name is not None: body['catalog_name'] = self.catalog_name + if self.columns: body['columns'] = [v.as_dict() for v in self.columns] + if self.full_name is not None: body['full_name'] = self.full_name + if self.name is not 
None: body['name'] = self.name + if self.schema_name is not None: body['schema_name'] = self.schema_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CleanRoomTableInfo': + return cls(catalog_name=d.get('catalog_name', None), + columns=_repeated(d, 'columns', ColumnInfo), + full_name=d.get('full_name', None), + name=d.get('name', None), + schema_name=d.get('schema_name', None)) + + +@dataclass +class ColumnInfo: + comment: Optional[str] = None + mask: Optional['ColumnMask'] = None + name: Optional[str] = None + nullable: Optional[bool] = None + partition_index: Optional[int] = None + position: Optional[int] = None + type_interval_type: Optional[str] = None + type_json: Optional[str] = None + type_name: Optional['ColumnTypeName'] = None + type_precision: Optional[int] = None + type_scale: Optional[int] = None + type_text: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.comment is not None: body['comment'] = self.comment + if self.mask: body['mask'] = self.mask.as_dict() + if self.name is not None: body['name'] = self.name + if self.nullable is not None: body['nullable'] = self.nullable + if self.partition_index is not None: body['partition_index'] = self.partition_index + if self.position is not None: body['position'] = self.position + if self.type_interval_type is not None: body['type_interval_type'] = self.type_interval_type + if self.type_json is not None: body['type_json'] = self.type_json + if self.type_name is not None: body['type_name'] = self.type_name.value + if self.type_precision is not None: body['type_precision'] = self.type_precision + if self.type_scale is not None: body['type_scale'] = self.type_scale + if self.type_text is not None: body['type_text'] = self.type_text + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ColumnInfo': + return cls(comment=d.get('comment', None), + mask=_from_dict(d, 'mask', ColumnMask), + name=d.get('name', None), + nullable=d.get('nullable', None), + partition_index=d.get('partition_index', None), + position=d.get('position', None), + type_interval_type=d.get('type_interval_type', None), + type_json=d.get('type_json', None), + type_name=_enum(d, 'type_name', ColumnTypeName), + type_precision=d.get('type_precision', None), + type_scale=d.get('type_scale', None), + type_text=d.get('type_text', None)) + + +@dataclass +class ColumnMask: + function_name: Optional[str] = None + using_column_names: Optional['List[str]'] = None + + def as_dict(self) -> dict: + body = {} + if self.function_name is not None: body['function_name'] = self.function_name + if self.using_column_names: body['using_column_names'] = [v for v in self.using_column_names] + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ColumnMask': + return cls(function_name=d.get('function_name', None), + using_column_names=d.get('using_column_names', None)) + + +class ColumnTypeName(Enum): + """Name of type (INT, STRUCT, MAP, etc.).""" + + ARRAY = 'ARRAY' + BINARY = 'BINARY' + BOOLEAN = 'BOOLEAN' + BYTE = 'BYTE' + CHAR = 'CHAR' + DATE = 'DATE' + DECIMAL = 'DECIMAL' + DOUBLE = 'DOUBLE' + FLOAT = 'FLOAT' + INT = 'INT' + INTERVAL = 'INTERVAL' + LONG = 'LONG' + MAP = 'MAP' + NULL = 'NULL' + SHORT = 'SHORT' + STRING = 'STRING' + STRUCT = 'STRUCT' + TABLE_TYPE = 'TABLE_TYPE' + TIMESTAMP = 'TIMESTAMP' + TIMESTAMP_NTZ = 'TIMESTAMP_NTZ' + USER_DEFINED_TYPE = 'USER_DEFINED_TYPE' + + +@dataclass +class CreateCleanRoom: + name: str + remote_detailed_info: 'CentralCleanRoomInfo' + comment: Optional[str] 
= None + + def as_dict(self) -> dict: + body = {} + if self.comment is not None: body['comment'] = self.comment + if self.name is not None: body['name'] = self.name + if self.remote_detailed_info: body['remote_detailed_info'] = self.remote_detailed_info.as_dict() + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'CreateCleanRoom': + return cls(comment=d.get('comment', None), + name=d.get('name', None), + remote_detailed_info=_from_dict(d, 'remote_detailed_info', CentralCleanRoomInfo)) + + @dataclass class CreateProvider: name: str @@ -96,6 +393,13 @@ def from_dict(cls, d: Dict[str, any]) -> 'CreateShare': return cls(comment=d.get('comment', None), name=d.get('name', None)) +@dataclass +class DeleteCleanRoomRequest: + """Delete a clean room""" + + name_arg: str + + @dataclass class DeleteProviderRequest: """Delete a provider""" @@ -124,6 +428,14 @@ class GetActivationUrlInfoRequest: activation_url: str +@dataclass +class GetCleanRoomRequest: + """Get a clean room""" + + name_arg: str + include_remote_details: Optional[bool] = None + + @dataclass class GetProviderRequest: """Get a provider""" @@ -174,6 +486,20 @@ def from_dict(cls, d: Dict[str, any]) -> 'IpAccessList': return cls(allowed_ip_addresses=d.get('allowed_ip_addresses', None)) +@dataclass +class ListCleanRoomsResponse: + clean_rooms: Optional['List[CleanRoomInfo]'] = None + + def as_dict(self) -> dict: + body = {} + if self.clean_rooms: body['clean_rooms'] = [v.as_dict() for v in self.clean_rooms] + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'ListCleanRoomsResponse': + return cls(clean_rooms=_repeated(d, 'clean_rooms', CleanRoomInfo)) + + @dataclass class ListProviderSharesResponse: shares: Optional['List[ProviderShare]'] = None @@ -670,6 +996,7 @@ class SharedDataObject: cdf_enabled: Optional[bool] = None comment: Optional[str] = None data_object_type: Optional[str] = None + history_data_sharing_status: Optional['SharedDataObjectHistoryDataSharingStatus'] = None partitions: Optional['List[Partition]'] = None shared_as: Optional[str] = None start_version: Optional[int] = None @@ -682,6 +1009,8 @@ def as_dict(self) -> dict: if self.cdf_enabled is not None: body['cdf_enabled'] = self.cdf_enabled if self.comment is not None: body['comment'] = self.comment if self.data_object_type is not None: body['data_object_type'] = self.data_object_type + if self.history_data_sharing_status is not None: + body['history_data_sharing_status'] = self.history_data_sharing_status.value if self.name is not None: body['name'] = self.name if self.partitions: body['partitions'] = [v.as_dict() for v in self.partitions] if self.shared_as is not None: body['shared_as'] = self.shared_as @@ -696,6 +1025,8 @@ def from_dict(cls, d: Dict[str, any]) -> 'SharedDataObject': cdf_enabled=d.get('cdf_enabled', None), comment=d.get('comment', None), data_object_type=d.get('data_object_type', None), + history_data_sharing_status=_enum(d, 'history_data_sharing_status', + SharedDataObjectHistoryDataSharingStatus), name=d.get('name', None), partitions=_repeated(d, 'partitions', Partition), shared_as=d.get('shared_as', None), @@ -703,6 +1034,14 @@ def from_dict(cls, d: Dict[str, any]) -> 'SharedDataObject': status=_enum(d, 'status', SharedDataObjectStatus)) +class SharedDataObjectHistoryDataSharingStatus(Enum): + """Whether to enable or disable sharing of data history. 
If not specified, the default is + **DISABLED**.""" + + DISABLED = 'DISABLED' + ENABLED = 'ENABLED' + + class SharedDataObjectStatus(Enum): """One of: **ACTIVE**, **PERMISSION_DENIED**.""" @@ -735,6 +1074,32 @@ class SharedDataObjectUpdateAction(Enum): UPDATE = 'UPDATE' +@dataclass +class UpdateCleanRoom: + catalog_updates: Optional['List[CleanRoomCatalogUpdate]'] = None + comment: Optional[str] = None + name: Optional[str] = None + name_arg: Optional[str] = None + owner: Optional[str] = None + + def as_dict(self) -> dict: + body = {} + if self.catalog_updates: body['catalog_updates'] = [v.as_dict() for v in self.catalog_updates] + if self.comment is not None: body['comment'] = self.comment + if self.name is not None: body['name'] = self.name + if self.name_arg is not None: body['name_arg'] = self.name_arg + if self.owner is not None: body['owner'] = self.owner + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'UpdateCleanRoom': + return cls(catalog_updates=_repeated(d, 'catalog_updates', CleanRoomCatalogUpdate), + comment=d.get('comment', None), + name=d.get('name', None), + name_arg=d.get('name_arg', None), + owner=d.get('owner', None)) + + @dataclass class UpdateProvider: comment: Optional[str] = None @@ -823,6 +1188,145 @@ def from_dict(cls, d: Dict[str, any]) -> 'UpdateSharePermissions': return cls(changes=_repeated(d, 'changes', catalog.PermissionsChange), name=d.get('name', None)) +class CleanRoomsAPI: + """A clean room is a secure, privacy-protecting environment where two or more parties can share sensitive + enterprise data, including customer data, for measurements, insights, activation and other use cases. + + To create clean rooms, you must be a metastore admin or a user with the **CREATE_CLEAN_ROOM** privilege.""" + + def __init__(self, api_client): + self._api = api_client + + def create(self, + name: str, + remote_detailed_info: CentralCleanRoomInfo, + *, + comment: Optional[str] = None, + **kwargs) -> CleanRoomInfo: + """Create a clean room. + + Creates a new clean room with specified collaborators. The caller must be a metastore admin or have the + **CREATE_CLEAN_ROOM** privilege on the metastore. + + :param name: str + Name of the clean room. + :param remote_detailed_info: :class:`CentralCleanRoomInfo` + Central clean room details. + :param comment: str (optional) + User-provided free-form text description. + + :returns: :class:`CleanRoomInfo` + """ + request = kwargs.get('request', None) + if not request: # request is not given through keyed args + request = CreateCleanRoom(comment=comment, name=name, remote_detailed_info=remote_detailed_info) + body = request.as_dict() + + json = self._api.do('POST', '/api/2.1/unity-catalog/clean-rooms', body=body) + return CleanRoomInfo.from_dict(json) + + def delete(self, name_arg: str, **kwargs): + """Delete a clean room. + + Deletes a data object clean room from the metastore. The caller must be an owner of the clean room. + + :param name_arg: str + The name of the clean room. + + + """ + request = kwargs.get('request', None) + if not request: # request is not given through keyed args + request = DeleteCleanRoomRequest(name_arg=name_arg) + + self._api.do('DELETE', f'/api/2.1/unity-catalog/clean-rooms/{request.name_arg}') + + def get(self, name_arg: str, *, include_remote_details: Optional[bool] = None, **kwargs) -> CleanRoomInfo: + """Get a clean room. + + Gets a data object clean room from the metastore. The caller must be a metastore admin or the owner of + the clean room.
+ + :param name_arg: str + The name of the clean room. + :param include_remote_details: bool (optional) + Whether to include remote details (central) on the clean room. + + :returns: :class:`CleanRoomInfo` + """ + request = kwargs.get('request', None) + if not request: # request is not given through keyed args + request = GetCleanRoomRequest(include_remote_details=include_remote_details, name_arg=name_arg) + + query = {} + if include_remote_details: query['include_remote_details'] = request.include_remote_details + + json = self._api.do('GET', f'/api/2.1/unity-catalog/clean-rooms/{request.name_arg}', query=query) + return CleanRoomInfo.from_dict(json) + + def list(self) -> Iterator[CleanRoomInfo]: + """List clean rooms. + + Gets an array of data object clean rooms from the metastore. The caller must be a metastore admin or + the owner of the clean room. There is no guarantee of a specific ordering of the elements in the + array. + + :returns: Iterator over :class:`CleanRoomInfo` + """ + + json = self._api.do('GET', '/api/2.1/unity-catalog/clean-rooms') + return [CleanRoomInfo.from_dict(v) for v in json.get('clean_rooms', [])] + + def update(self, + name_arg: str, + *, + catalog_updates: Optional[List[CleanRoomCatalogUpdate]] = None, + comment: Optional[str] = None, + name: Optional[str] = None, + owner: Optional[str] = None, + **kwargs) -> CleanRoomInfo: + """Update a clean room. + + Updates the clean room with the changes and data objects in the request. The caller must be the owner + of the clean room or a metastore admin. + + When the caller is a metastore admin, only the __owner__ field can be updated. + + In the case that the clean room name is changed **updateCleanRoom** requires that the caller is both + the clean room owner and a metastore admin. + + For each table that is added through this method, the clean room owner must also have **SELECT** + privilege on the table. The privilege must be maintained indefinitely for recipients to be able to + access the table. Typically, you should use a group as the clean room owner. + + Table removals through **update** do not require additional privileges. + + :param name_arg: str + The name of the clean room. + :param catalog_updates: List[:class:`CleanRoomCatalogUpdate`] (optional) + Array of shared data object updates. + :param comment: str (optional) + User-provided free-form text description. + :param name: str (optional) + Name of the clean room. + :param owner: str (optional) + Username of current owner of clean room. 
+ + :returns: :class:`CleanRoomInfo` + """ + request = kwargs.get('request', None) + if not request: # request is not given through keyed args + request = UpdateCleanRoom(catalog_updates=catalog_updates, + comment=comment, + name=name, + name_arg=name_arg, + owner=owner) + body = request.as_dict() + + json = self._api.do('PATCH', f'/api/2.1/unity-catalog/clean-rooms/{request.name_arg}', body=body) + return CleanRoomInfo.from_dict(json) + + class ProvidersAPI: """Databricks Providers REST API""" diff --git a/databricks/sdk/service/sql.py b/databricks/sdk/service/sql.py index fa8af3019..261a0822b 100755 --- a/databricks/sdk/service/sql.py +++ b/databricks/sdk/service/sql.py @@ -44,7 +44,7 @@ class Alert: name: Optional[str] = None options: Optional['AlertOptions'] = None parent: Optional[str] = None - query: Optional['Query'] = None + query: Optional['AlertQuery'] = None rearm: Optional[int] = None state: Optional['AlertState'] = None updated_at: Optional[str] = None @@ -73,7 +73,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'Alert': name=d.get('name', None), options=_from_dict(d, 'options', AlertOptions), parent=d.get('parent', None), - query=_from_dict(d, 'query', Query), + query=_from_dict(d, 'query', AlertQuery), rearm=d.get('rearm', None), state=_enum(d, 'state', AlertState), updated_at=d.get('updated_at', None), @@ -86,7 +86,7 @@ class AlertOptions: column: str op: str - value: str + value: Any custom_body: Optional[str] = None custom_subject: Optional[str] = None muted: Optional[bool] = None @@ -98,7 +98,7 @@ def as_dict(self) -> dict: if self.custom_subject is not None: body['custom_subject'] = self.custom_subject if self.muted is not None: body['muted'] = self.muted if self.op is not None: body['op'] = self.op - if self.value is not None: body['value'] = self.value + if self.value: body['value'] = self.value return body @classmethod @@ -111,13 +111,63 @@ def from_dict(cls, d: Dict[str, any]) -> 'AlertOptions': value=d.get('value', None)) +@dataclass +class AlertQuery: + created_at: Optional[str] = None + data_source_id: Optional[str] = None + description: Optional[str] = None + id: Optional[str] = None + is_archived: Optional[bool] = None + is_draft: Optional[bool] = None + is_safe: Optional[bool] = None + name: Optional[str] = None + options: Optional['QueryOptions'] = None + query: Optional[str] = None + tags: Optional['List[str]'] = None + updated_at: Optional[str] = None + user_id: Optional[int] = None + + def as_dict(self) -> dict: + body = {} + if self.created_at is not None: body['created_at'] = self.created_at + if self.data_source_id is not None: body['data_source_id'] = self.data_source_id + if self.description is not None: body['description'] = self.description + if self.id is not None: body['id'] = self.id + if self.is_archived is not None: body['is_archived'] = self.is_archived + if self.is_draft is not None: body['is_draft'] = self.is_draft + if self.is_safe is not None: body['is_safe'] = self.is_safe + if self.name is not None: body['name'] = self.name + if self.options: body['options'] = self.options.as_dict() + if self.query is not None: body['query'] = self.query + if self.tags: body['tags'] = [v for v in self.tags] + if self.updated_at is not None: body['updated_at'] = self.updated_at + if self.user_id is not None: body['user_id'] = self.user_id + return body + + @classmethod + def from_dict(cls, d: Dict[str, any]) -> 'AlertQuery': + return cls(created_at=d.get('created_at', None), + data_source_id=d.get('data_source_id', None), + description=d.get('description', 
None), + id=d.get('id', None), + is_archived=d.get('is_archived', None), + is_draft=d.get('is_draft', None), + is_safe=d.get('is_safe', None), + name=d.get('name', None), + options=_from_dict(d, 'options', QueryOptions), + query=d.get('query', None), + tags=d.get('tags', None), + updated_at=d.get('updated_at', None), + user_id=d.get('user_id', None)) + + class AlertState(Enum): """State of the alert. Possible values are: `unknown` (yet to be evaluated), `triggered` (evaluated and fulfilled trigger conditions), or `ok` (evaluated and did not fulfill trigger conditions).""" - ok = 'ok' - triggered = 'triggered' - unknown = 'unknown' + OK = 'ok' + TRIGGERED = 'triggered' + UNKNOWN = 'unknown' @dataclass @@ -162,7 +212,6 @@ def from_dict(cls, d: Dict[str, any]) -> 'ChannelInfo': class ChannelName(Enum): - """Name of the channel""" CHANNEL_NAME_CURRENT = 'CHANNEL_NAME_CURRENT' CHANNEL_NAME_CUSTOM = 'CHANNEL_NAME_CUSTOM' @@ -1182,8 +1231,8 @@ class ListDashboardsRequest: class ListOrder(Enum): - created_at = 'created_at' - name = 'name' + CREATED_AT = 'created_at' + NAME = 'name' @dataclass @@ -1273,19 +1322,19 @@ def from_dict(cls, d: Dict[str, any]) -> 'ListWarehousesResponse': class ObjectType(Enum): """A singular noun object type.""" - alert = 'alert' - dashboard = 'dashboard' - data_source = 'data_source' - query = 'query' + ALERT = 'alert' + DASHBOARD = 'dashboard' + DATA_SOURCE = 'data_source' + QUERY = 'query' class ObjectTypePlural(Enum): """Always a plural of the object type.""" - alerts = 'alerts' - dashboards = 'dashboards' - data_sources = 'data_sources' - queries = 'queries' + ALERTS = 'alerts' + DASHBOARDS = 'dashboards' + DATA_SOURCES = 'data_sources' + QUERIES = 'queries' @dataclass @@ -1314,9 +1363,9 @@ def from_dict(cls, d: Dict[str, any]) -> 'OdbcParams': class OwnableObjectType(Enum): """The singular form of the type of object which can be owned.""" - alert = 'alert' - dashboard = 'dashboard' - query = 'query' + ALERT = 'alert' + DASHBOARD = 'dashboard' + QUERY = 'query' @dataclass @@ -1345,9 +1394,9 @@ def from_dict(cls, d: Dict[str, any]) -> 'Parameter': class ParameterType(Enum): """Parameters can have several different types.""" - datetime = 'datetime' - number = 'number' - text = 'text' + DATETIME = 'datetime' + NUMBER = 'number' + TEXT = 'text' class PermissionLevel(Enum): @@ -2107,7 +2156,7 @@ def from_dict(cls, d: Dict[str, any]) -> 'Success': class SuccessMessage(Enum): - Success = 'Success' + SUCCESS = 'Success' @dataclass @@ -2293,26 +2342,18 @@ def from_dict(cls, d: Dict[str, any]) -> 'TransferOwnershipRequest': class User: email: Optional[str] = None id: Optional[int] = None - is_db_admin: Optional[bool] = None name: Optional[str] = None - profile_image_url: Optional[str] = None def as_dict(self) -> dict: body = {} if self.email is not None: body['email'] = self.email if self.id is not None: body['id'] = self.id - if self.is_db_admin is not None: body['is_db_admin'] = self.is_db_admin if self.name is not None: body['name'] = self.name - if self.profile_image_url is not None: body['profile_image_url'] = self.profile_image_url return body @classmethod def from_dict(cls, d: Dict[str, any]) -> 'User': - return cls(email=d.get('email', None), - id=d.get('id', None), - is_db_admin=d.get('is_db_admin', None), - name=d.get('name', None), - profile_image_url=d.get('profile_image_url', None)) + return cls(email=d.get('email', None), id=d.get('id', None), name=d.get('name', None)) @dataclass @@ -2459,9 +2500,9 @@ def create(self, :param options: :class:`AlertOptions` Alert 
configuration options. :param query_id: str - ID of the query evaluated by the alert. + Query ID. :param parent: str (optional) - The identifier of the workspace folder containing the alert. The default is ther user's home folder. + The identifier of the workspace folder containing the object. :param rearm: int (optional) Number of seconds after being triggered before the alert rearms itself and can be triggered again. If `null`, alert will never be triggered again. @@ -2536,7 +2577,7 @@ def update(self, :param options: :class:`AlertOptions` Alert configuration options. :param query_id: str - ID of the query evaluated by the alert. + Query ID. :param alert_id: str :param rearm: int (optional) Number of seconds after being triggered before the alert rearms itself and can be triggered again. @@ -2576,8 +2617,7 @@ def create(self, :param name: str (optional) The title of this dashboard that appears in list views and at the top of the dashboard page. :param parent: str (optional) - The identifier of the workspace folder containing the dashboard. The default is the user's home - folder. + The identifier of the workspace folder containing the object. :param tags: List[str] (optional) :returns: :class:`Dashboard` @@ -2841,19 +2881,19 @@ def create(self, **Note**: You cannot add a visualization until you create the query. :param data_source_id: str (optional) - The ID of the data source / SQL warehouse where this query will run. + Data source ID. :param description: str (optional) - General description that can convey additional information about this query such as usage notes. + General description that conveys additional information about this query such as usage notes. :param name: str (optional) - The name or title of this query to display in list views. + The title of this query that appears in list views, widget headings, and on the query page. :param options: Any (optional) Exclusively used for storing a list parameter definitions. A parameter is an object with `title`, `name`, `type`, and `value` properties. The `value` field here is the default value. It can be overridden at runtime. :param parent: str (optional) - The identifier of the workspace folder containing the query. The default is the user's home folder. + The identifier of the workspace folder containing the object. :param query: str (optional) - The text of the query. + The text of the query to be run. :returns: :class:`Query` """ @@ -2995,17 +3035,17 @@ def update(self, :param query_id: str :param data_source_id: str (optional) - The ID of the data source / SQL warehouse where this query will run. + Data source ID. :param description: str (optional) - General description that can convey additional information about this query such as usage notes. + General description that conveys additional information about this query such as usage notes. :param name: str (optional) - The name or title of this query to display in list views. + The title of this query that appears in list views, widget headings, and on the query page. :param options: Any (optional) Exclusively used for storing a list parameter definitions. A parameter is an object with `title`, `name`, `type`, and `value` properties. The `value` field here is the default value. It can be overridden at runtime. :param query: str (optional) - The text of the query. + The text of the query to be run. 
:returns: :class:`Query` """ diff --git a/docs/gen-client-docs.py b/docs/gen-client-docs.py index a84ffe3fb..96dd1017e 100644 --- a/docs/gen-client-docs.py +++ b/docs/gen-client-docs.py @@ -134,21 +134,26 @@ class Generator: def __init__(self): self.mapping = self._load_mapping() - def _load_mapping(self) -> dict[str, Tag]: - mapping = {} - pkgs = {p.name: p for p in self.packages} + def _spec_file(self) -> str: + if 'DATABRICKS_OPENAPI_SPEC' in os.environ: + return os.environ['DATABRICKS_OPENAPI_SPEC'] with open(os.path.expanduser('~/.openapi-codegen.json'), 'r') as f: config = json.load(f) if 'spec' not in config: raise ValueError('Cannot find OpenAPI spec') - with open(config['spec'], 'r') as fspec: - spec = json.load(fspec) - for tag in spec['tags']: - t = Tag(name=tag['name'], - service=tag['x-databricks-service'], - is_account=tag.get('x-databricks-is-accounts', False), - package=pkgs[tag['x-databricks-package']]) - mapping[tag['name']] = t + return config['spec'] + + def _load_mapping(self) -> dict[str, Tag]: + mapping = {} + pkgs = {p.name: p for p in self.packages} + with open(self._spec_file(), 'r') as fspec: + spec = json.load(fspec) + for tag in spec['tags']: + t = Tag(name=tag['name'], + service=tag['x-databricks-service'], + is_account=tag.get('x-databricks-is-accounts', False), + package=pkgs[tag['x-databricks-package']]) + mapping[tag['name']] = t return mapping def class_methods(self, inst) -> list[MethodDoc]: @@ -216,12 +221,12 @@ def _write_client_packages(self, folder: str, label: str, description: str, pack f.write(f''' {label} {'=' * len(label)} - + {description} - + .. toctree:: :maxdepth: 1 - + {all}''') def _write_client_package_doc(self, folder: str, pkg: Package, services: list[str]): @@ -230,12 +235,12 @@ def _write_client_package_doc(self, folder: str, pkg: Package, services: list[st f.write(f''' {pkg.label} {'=' * len(pkg.label)} - + {pkg.description} - + .. 
toctree:: :maxdepth: 1 - + {all}''') diff --git a/examples/clusters/ensure_cluster_is_running_commands_direct_usage.py b/examples/clusters/ensure_cluster_is_running_commands_direct_usage.py index 82fcfece3..4d93f3d3b 100755 --- a/examples/clusters/ensure_cluster_is_running_commands_direct_usage.py +++ b/examples/clusters/ensure_cluster_is_running_commands_direct_usage.py @@ -7,7 +7,7 @@ cluster_id = os.environ["TEST_DEFAULT_CLUSTER_ID"] -context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.python).result() +context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.PYTHON).result() w.clusters.ensure_cluster_is_running(cluster_id) diff --git a/examples/command_execution/create_commands_direct_usage.py b/examples/command_execution/create_commands_direct_usage.py index c5c8e3354..fc61e55f3 100755 --- a/examples/command_execution/create_commands_direct_usage.py +++ b/examples/command_execution/create_commands_direct_usage.py @@ -7,7 +7,7 @@ cluster_id = os.environ["TEST_DEFAULT_CLUSTER_ID"] -context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.python).result() +context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.PYTHON).result() # cleanup w.command_execution.destroy(cluster_id=cluster_id, context_id=context.id) diff --git a/examples/command_execution/execute_commands_direct_usage.py b/examples/command_execution/execute_commands_direct_usage.py index 15297b765..98fa13a19 100755 --- a/examples/command_execution/execute_commands_direct_usage.py +++ b/examples/command_execution/execute_commands_direct_usage.py @@ -7,11 +7,11 @@ cluster_id = os.environ["TEST_DEFAULT_CLUSTER_ID"] -context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.python).result() +context = w.command_execution.create(cluster_id=cluster_id, language=compute.Language.PYTHON).result() text_results = w.command_execution.execute(cluster_id=cluster_id, context_id=context.id, - language=compute.Language.python, + language=compute.Language.PYTHON, command="print(1)").result() # cleanup diff --git a/examples/command_execution/start_commands.py b/examples/command_execution/start_commands.py index 261f9fc02..9901a2f2c 100755 --- a/examples/command_execution/start_commands.py +++ b/examples/command_execution/start_commands.py @@ -7,4 +7,4 @@ cluster_id = os.environ["TEST_DEFAULT_CLUSTER_ID"] -command_context = w.command_execution.start(cluster_id, compute.Language.python) +command_context = w.command_execution.start(cluster_id, compute.Language.PYTHON) diff --git a/examples/users/list_users.py b/examples/users/list_users.py index 87c66adf8..fbc477cbd 100755 --- a/examples/users/list_users.py +++ b/examples/users/list_users.py @@ -5,4 +5,4 @@ all_users = w.users.list(attributes="id,userName", sort_by="userName", - sort_order=iam.ListSortOrder.descending) + sort_order=iam.ListSortOrder.DESCENDING) diff --git a/setup.py b/setup.py index 35d592328..7f7106784 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,6 @@ -import io, pathlib +import io +import pathlib + from setuptools import setup, find_packages version_data = {} @@ -12,7 +14,9 @@ python_requires=">=3.7", install_requires=["requests>=2.28.1,<3"], extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock", - "yapf", "pycodestyle", "autoflake", "isort", "wheel"]}, + "yapf", "pycodestyle", "autoflake", "isort", "wheel", + "ipython", "ipywidgets"], + "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]}, author="Serge 
Smertin", author_email="serge.smertin@databricks.com", description="Databricks SDK for Python (Beta)", diff --git a/tests/test_core.py b/tests/test_core.py index f00c0b740..15ee7f456 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,209 +1,225 @@ -import os -import pathlib -import platform -import random -import string - -import pytest - -from databricks.sdk.core import (Config, CredentialsProvider, - DatabricksCliTokenSource, HeaderFactory, - databricks_cli) -from databricks.sdk.version import __version__ - - -def test_parse_dsn(): - cfg = Config.parse_dsn('databricks://user:pass@foo.databricks.com?retry_timeout_seconds=600') - - headers = cfg.authenticate() - - assert headers['Authorization'] == 'Basic dXNlcjpwYXNz' - assert 'basic' == cfg.auth_type - - -def test_databricks_cli_token_source_relative_path(config): - config.databricks_cli_path = "./relative/path/to/cli" - ts = DatabricksCliTokenSource(config) - assert ts._cmd[0] == config.databricks_cli_path - - -def test_databricks_cli_token_source_absolute_path(config): - config.databricks_cli_path = "/absolute/path/to/cli" - ts = DatabricksCliTokenSource(config) - assert ts._cmd[0] == config.databricks_cli_path - - -def test_databricks_cli_token_source_not_installed(config, monkeypatch): - monkeypatch.setenv('PATH', 'whatever') - with pytest.raises(FileNotFoundError, match="not installed"): - DatabricksCliTokenSource(config) - - -def write_small_dummy_executable(path: pathlib.Path): - cli = path.joinpath('databricks') - cli.write_text('#!/bin/sh\necho "hello world"\n') - cli.chmod(0o755) - assert cli.stat().st_size < 1024 - return cli - - -def write_large_dummy_executable(path: pathlib.Path): - cli = path.joinpath('databricks') - - # Generate a long random string to inflate the file size. - random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024 * 1024)) - cli.write_text("""#!/bin/sh -cat <= (1024 * 1024) - return cli - - -def test_databricks_cli_token_source_installed_legacy(config, monkeypatch, tmp_path): - write_small_dummy_executable(tmp_path) - monkeypatch.setenv('PATH', tmp_path.as_posix()) - with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"): - DatabricksCliTokenSource(config) - - -def test_databricks_cli_token_source_installed_legacy_with_symlink(config, monkeypatch, tmp_path): - dir1 = tmp_path.joinpath('dir1') - dir2 = tmp_path.joinpath('dir2') - dir1.mkdir() - dir2.mkdir() - - (dir1 / "databricks").symlink_to(write_small_dummy_executable(dir2)) - - monkeypatch.setenv('PATH', dir1.as_posix()) - with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"): - DatabricksCliTokenSource(config) - - -def test_databricks_cli_token_source_installed_new(config, monkeypatch, tmp_path): - write_large_dummy_executable(tmp_path) - monkeypatch.setenv('PATH', tmp_path.as_posix()) - DatabricksCliTokenSource(config) - - -def test_databricks_cli_token_source_installed_both(config, monkeypatch, tmp_path): - dir1 = tmp_path.joinpath('dir1') - dir2 = tmp_path.joinpath('dir2') - dir1.mkdir() - dir2.mkdir() - - write_small_dummy_executable(dir1) - write_large_dummy_executable(dir2) - - # Resolve small before large. - monkeypatch.setenv('PATH', str(os.pathsep).join([dir1.as_posix(), dir2.as_posix()])) - DatabricksCliTokenSource(config) - - # Resolve large before small. 
- monkeypatch.setenv('PATH', str(os.pathsep).join([dir2.as_posix(), dir1.as_posix()])) - DatabricksCliTokenSource(config) - - -def test_databricks_cli_credential_provider_not_installed(config, monkeypatch): - monkeypatch.setenv('PATH', 'whatever') - assert databricks_cli(config) == None - - -def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch, tmp_path): - write_small_dummy_executable(tmp_path) - monkeypatch.setenv('PATH', tmp_path.as_posix()) - assert databricks_cli(config) == None - - -def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path): - write_large_dummy_executable(tmp_path) - monkeypatch.setenv('PATH', str(os.pathsep).join([tmp_path.as_posix(), os.environ['PATH']])) - assert databricks_cli(config) is not None - - -def test_extra_and_upstream_user_agent(monkeypatch): - - class MockUname: - - @property - def system(self): - return 'TestOS' - - monkeypatch.setattr(platform, 'python_version', lambda: '3.0.0') - monkeypatch.setattr(platform, 'uname', MockUname) - monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM', "upstream-product") - monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM_VERSION', "0.0.1") - - config = Config(host='http://localhost', username="something", password="something", product='test', - product_version='0.0.0') \ - .with_user_agent_extra('test-extra-1', '1') \ - .with_user_agent_extra('test-extra-2', '2') - - assert config.user_agent == ( - f"test/0.0.0 databricks-sdk-py/{__version__} python/3.0.0 os/testos auth/basic" - f" test-extra-1/1 test-extra-2/2 upstream/upstream-product upstream-version/0.0.1") - - -def test_config_copy_shallow_copies_credential_provider(): - - class TestCredentialsProvider(CredentialsProvider): - - def __init__(self): - super().__init__() - self._token = "token1" - - def auth_type(self) -> str: - return "test" - - def __call__(self, cfg: 'Config') -> HeaderFactory: - return lambda: {"token": self._token} - - def refresh(self): - self._token = "token2" - - credential_provider = TestCredentialsProvider() - config = Config(credentials_provider=credential_provider) - config_copy = config.copy() - - assert config.authenticate()["token"] == "token1" - assert config_copy.authenticate()["token"] == "token1" - - credential_provider.refresh() - - assert config.authenticate()["token"] == "token2" - assert config_copy.authenticate()["token"] == "token2" - assert config._credentials_provider == config_copy._credentials_provider - - -def test_config_copy_deep_copies_user_agent_other_info(config): - config_copy = config.copy() - - config.with_user_agent_extra("test", "test1") - assert "test/test1" not in config_copy.user_agent - assert "test/test1" in config.user_agent - - config_copy.with_user_agent_extra("test", "test2") - assert "test/test2" in config_copy.user_agent - assert "test/test2" not in config.user_agent - - -def test_config_accounts_aws_is_accounts_host(config): - config.host = "https://accounts.cloud.databricks.com" - assert config.is_account_client - - -def test_config_accounts_dod_is_accounts_host(config): - config.host = "https://accounts-dod.cloud.databricks.us" - assert config.is_account_client - - -def test_config_workspace_is_not_accounts_host(config): - config.host = "https://westeurope.azuredatabricks.net" - assert not config.is_account_client +import os +import pathlib +import platform +import random +import string + +import pytest + +from databricks.sdk.core import (Config, CredentialsProvider, + DatabricksCliTokenSource, HeaderFactory, + databricks_cli) +from 
databricks.sdk.version import __version__ + + +def test_parse_dsn(): + cfg = Config.parse_dsn('databricks://user:pass@foo.databricks.com?retry_timeout_seconds=600') + + headers = cfg.authenticate() + + assert headers['Authorization'] == 'Basic dXNlcjpwYXNz' + assert 'basic' == cfg.auth_type + + +def test_databricks_cli_token_source_relative_path(config): + config.databricks_cli_path = "./relative/path/to/cli" + ts = DatabricksCliTokenSource(config) + assert ts._cmd[0] == config.databricks_cli_path + + +def test_databricks_cli_token_source_absolute_path(config): + config.databricks_cli_path = "/absolute/path/to/cli" + ts = DatabricksCliTokenSource(config) + assert ts._cmd[0] == config.databricks_cli_path + + +def test_databricks_cli_token_source_not_installed(config, monkeypatch): + monkeypatch.setenv('PATH', 'whatever') + with pytest.raises(FileNotFoundError, match="not installed"): + DatabricksCliTokenSource(config) + + +def write_small_dummy_executable(path: pathlib.Path): + cli = path.joinpath('databricks') + cli.write_text('#!/bin/sh\necho "hello world"\n') + cli.chmod(0o755) + assert cli.stat().st_size < 1024 + return cli + + +def write_large_dummy_executable(path: pathlib.Path): + cli = path.joinpath('databricks') + + # Generate a long random string to inflate the file size. + random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024 * 1024)) + cli.write_text("""#!/bin/sh +cat <<EOF +{random_string} +EOF +""".format(random_string=random_string)) + cli.chmod(0o755) + assert cli.stat().st_size >= (1024 * 1024) + return cli + + +def test_databricks_cli_token_source_installed_legacy(config, monkeypatch, tmp_path): + write_small_dummy_executable(tmp_path) + monkeypatch.setenv('PATH', tmp_path.as_posix()) + with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"): + DatabricksCliTokenSource(config) + + +def test_databricks_cli_token_source_installed_legacy_with_symlink(config, monkeypatch, tmp_path): + dir1 = tmp_path.joinpath('dir1') + dir2 = tmp_path.joinpath('dir2') + dir1.mkdir() + dir2.mkdir() + + (dir1 / "databricks").symlink_to(write_small_dummy_executable(dir2)) + + monkeypatch.setenv('PATH', dir1.as_posix()) + with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"): + DatabricksCliTokenSource(config) + + +def test_databricks_cli_token_source_installed_new(config, monkeypatch, tmp_path): + write_large_dummy_executable(tmp_path) + monkeypatch.setenv('PATH', tmp_path.as_posix()) + DatabricksCliTokenSource(config) + + +def test_databricks_cli_token_source_installed_both(config, monkeypatch, tmp_path): + dir1 = tmp_path.joinpath('dir1') + dir2 = tmp_path.joinpath('dir2') + dir1.mkdir() + dir2.mkdir() + + write_small_dummy_executable(dir1) + write_large_dummy_executable(dir2) + + # Resolve small before large. + monkeypatch.setenv('PATH', str(os.pathsep).join([dir1.as_posix(), dir2.as_posix()])) + DatabricksCliTokenSource(config) + + # Resolve large before small.
+ monkeypatch.setenv('PATH', str(os.pathsep).join([dir2.as_posix(), dir1.as_posix()])) + DatabricksCliTokenSource(config) + + +def test_databricks_cli_credential_provider_not_installed(config, monkeypatch): + monkeypatch.setenv('PATH', 'whatever') + assert databricks_cli(config) == None + + +def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch, tmp_path): + write_small_dummy_executable(tmp_path) + monkeypatch.setenv('PATH', tmp_path.as_posix()) + assert databricks_cli(config) == None + + +def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path): + write_large_dummy_executable(tmp_path) + monkeypatch.setenv('PATH', str(os.pathsep).join([tmp_path.as_posix(), os.environ['PATH']])) + assert databricks_cli(config) is not None + + +def test_extra_and_upstream_user_agent(monkeypatch): + + class MockUname: + + @property + def system(self): + return 'TestOS' + + monkeypatch.setattr(platform, 'python_version', lambda: '3.0.0') + monkeypatch.setattr(platform, 'uname', MockUname) + monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM', "upstream-product") + monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM_VERSION', "0.0.1") + + config = Config(host='http://localhost', username="something", password="something", product='test', + product_version='0.0.0') \ + .with_user_agent_extra('test-extra-1', '1') \ + .with_user_agent_extra('test-extra-2', '2') + + assert config.user_agent == ( + f"test/0.0.0 databricks-sdk-py/{__version__} python/3.0.0 os/testos auth/basic" + f" test-extra-1/1 test-extra-2/2 upstream/upstream-product upstream-version/0.0.1") + + +def test_config_copy_shallow_copies_credential_provider(): + + class TestCredentialsProvider(CredentialsProvider): + + def __init__(self): + super().__init__() + self._token = "token1" + + def auth_type(self) -> str: + return "test" + + def __call__(self, cfg: 'Config') -> HeaderFactory: + return lambda: {"token": self._token} + + def refresh(self): + self._token = "token2" + + credential_provider = TestCredentialsProvider() + config = Config(credentials_provider=credential_provider) + config_copy = config.copy() + + assert config.authenticate()["token"] == "token1" + assert config_copy.authenticate()["token"] == "token1" + + credential_provider.refresh() + + assert config.authenticate()["token"] == "token2" + assert config_copy.authenticate()["token"] == "token2" + assert config._credentials_provider == config_copy._credentials_provider + + +def test_config_copy_deep_copies_user_agent_other_info(config): + config_copy = config.copy() + + config.with_user_agent_extra("test", "test1") + assert "test/test1" not in config_copy.user_agent + assert "test/test1" in config.user_agent + + config_copy.with_user_agent_extra("test", "test2") + assert "test/test2" in config_copy.user_agent + assert "test/test2" not in config.user_agent + + +def test_config_accounts_aws_is_accounts_host(config): + config.host = "https://accounts.cloud.databricks.com" + assert config.is_account_client + + +def test_config_accounts_dod_is_accounts_host(config): + config.host = "https://accounts-dod.cloud.databricks.us" + assert config.is_account_client + + +def test_config_workspace_is_not_accounts_host(config): + config.host = "https://westeurope.azuredatabricks.net" + assert not config.is_account_client + + +def test_config_can_be_subclassed(): + + class DatabricksConfig(Config): + + def __init__(self): + super().__init__() + + with pytest.raises(ValueError): # As opposed to `KeyError`. 
+ DatabricksConfig() + + +if __name__ == "__main__": + import conftest + test_config_accounts_dod_is_accounts_host(conftest.config.__pytest_wrapped__.obj()) diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index ac7e3102b..99d363b4a 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -117,14 +117,14 @@ def inner(results_data: any, expect_command: str): command_execute = mocker.patch( 'databricks.sdk.service.compute.CommandExecutionAPI.execute', return_value=Wait(lambda **kwargs: CommandStatusResponse( - results=Results(data=json.dumps(results_data)), status=CommandStatus.Finished))) + results=Results(data=json.dumps(results_data)), status=CommandStatus.FINISHED))) def assertions(): cluster_get.assert_called_with('x') - context_create.assert_called_with(cluster_id='x', language=Language.python) + context_create.assert_called_with(cluster_id='x', language=Language.PYTHON) command_execute.assert_called_with(cluster_id='x', context_id='y', - language=Language.python, + language=Language.PYTHON, command=expect_command) dbutils = RemoteDbUtils( @@ -184,14 +184,14 @@ def test_any_proxy(dbutils_proxy): command = ('\n' ' import json\n' ' (args, kwargs) = json.loads(\'[["a"], {}]\')\n' - ' result = dbutils.widgets.getParameter(*args, **kwargs)\n' + ' result = dbutils.notebook.exit(*args, **kwargs)\n' ' dbutils.notebook.exit(json.dumps(result))\n' ' ') - dbutils, assertions = dbutils_proxy('b', command) + dbutils, assertions = dbutils_proxy('a', command) - param = dbutils.widgets.getParameter('a') + param = dbutils.notebook.exit("a") - assert param == 'b' + assert param == 'a' assertions()
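
For reference, a minimal usage sketch of the changes above. The upper-case enum constants and the command-execution calls come straight from the updated examples; the `w.clean_rooms` attribute name (and the `name` field on the returned clean room objects) is assumed here, based on how other generated services are exposed on the client, and `TEST_DEFAULT_CLUSTER_ID` is assumed to name a running cluster as in the examples.

import os

from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute

w = WorkspaceClient()

# Enum members are now upper-case constants, e.g. compute.Language.PYTHON rather than compute.Language.python.
context = w.command_execution.create(cluster_id=os.environ["TEST_DEFAULT_CLUSTER_ID"],
                                     language=compute.Language.PYTHON).result()

# The new clean rooms service added in sharing.py (attribute name on WorkspaceClient assumed).
for room in w.clean_rooms.list():
    print(room.name)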