Commit

first work
mgyucht committed Oct 7, 2024
1 parent 5e871cb commit 230ed00
Showing 10 changed files with 988 additions and 721 deletions.
databricks/sdk/_base_client.py: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
import logging
import urllib.parse
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
                    Optional, Tuple, Type, Union)

import requests
import requests.adapters

from . import useragent
from .casing import Casing
from .clock import Clock, RealClock
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .retries import retried

logger = logging.getLogger('databricks.sdk')


def fix_host_if_needed(host: Optional[str]) -> Optional[str]:
if not host:
return host

# Add a default scheme if it's missing
if '://' not in host:
host = 'https://' + host

o = urllib.parse.urlparse(host)
# remove trailing slash
path = o.path.rstrip('/')
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]

return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
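

# Illustrative example with a hypothetical workspace URL: a bare hostname gains an
# https scheme, while a default port and trailing slash are stripped:
#   >>> fix_host_if_needed('my-workspace.cloud.databricks.com')
#   'https://my-workspace.cloud.databricks.com'
#   >>> fix_host_if_needed('https://my-workspace.cloud.databricks.com:443/')
#   'https://my-workspace.cloud.databricks.com'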


class _BaseClient:

def __init__(self,
                 debug_truncate_bytes: Optional[int] = None,
                 retry_timeout_seconds: Optional[int] = None,
                 user_agent_base: Optional[str] = None,
                 header_factory: Optional[Callable[[], dict]] = None,
                 max_connection_pools: Optional[int] = None,
                 max_connections_per_pool: Optional[int] = None,
                 pool_block: bool = True,
                 http_timeout_seconds: Optional[float] = None,
                 extra_error_customizers: Optional[List[_ErrorCustomizer]] = None,
                 debug_headers: bool = False,
                 clock: Optional[Clock] = None):
"""
        :param debug_truncate_bytes: Number of bytes at which request and response bodies are
            truncated in debug logs. Defaults to 96.
        :param retry_timeout_seconds: Total time budget, in seconds, for retrying a failed request.
            Defaults to 300.
        :param user_agent_base: Base User-Agent string sent with every request.
        :param header_factory: A function that returns a dictionary of headers to include in the request.
        :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least
            recently used pool. Python requests' default value is 10.
        :param max_connections_per_pool: The maximum number of connections to save in the pool. Improves
            performance in multithreaded situations. For now, we set it to the same default as
            max_connection_pools.
        :param pool_block: If True, block when no free connections are available, so that urllib3 ensures
            no more than max_connections_per_pool connections are used at a time; this prevents flooding
            the platform with requests. If False, additional connections are created but not saved after
            their first use. By default, the requests library doesn't block.
        :param http_timeout_seconds: Timeout, in seconds, for each HTTP request. Defaults to 60.
        :param extra_error_customizers: Additional error customizers applied when parsing API error
            responses.
        :param debug_headers: Whether to include headers in the request debug log.
        :param clock: Clock object to use for time-related operations.
        """

self._debug_truncate_bytes = debug_truncate_bytes or 96
self._debug_headers = debug_headers
self._retry_timeout_seconds = retry_timeout_seconds or 300
self._user_agent_base = user_agent_base or useragent.to_string()
self._header_factory = header_factory
self._clock = clock or RealClock()
self._session = requests.Session()
self._session.auth = self._authenticate

# We don't use `max_retries` from HTTPAdapter to align with a more production-ready
# retry strategy established in the Databricks SDK for Go. See _is_retryable and
# @retried for more details.
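        # In urllib3, `pool_connections` is the number of distinct connection pools to cache
        # (one per host), and `pool_maxsize` is the cap on connections kept per pool.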
        http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connection_pools or 20,
                                                     pool_maxsize=max_connections_per_pool or 20,
                                                     pool_block=pool_block)
self._session.mount("https://", http_adapter)

# Default to 60 seconds
self._http_timeout_seconds = http_timeout_seconds or 60

self._error_parser = _Parser(extra_error_customizers=extra_error_customizers)

def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
if self._header_factory:
headers = self._header_factory()
for k, v in headers.items():
r.headers[k] = v
return r

@staticmethod
def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
# Convert True -> "true" for Databricks APIs to understand booleans.
# See: https://github.com/databricks/databricks-sdk-py/issues/142
if query is None:
return None
        with_fixed_bools = {k: ('true' if v else 'false') if isinstance(v, bool) else v
                            for k, v in query.items()}

# Query parameters may be nested, e.g.
# {'filter_by': {'user_ids': [123, 456]}}
# The HTTP-compatible representation of this is
# filter_by.user_ids=123&filter_by.user_ids=456
# To achieve this, we convert the above dictionary to
# {'filter_by.user_ids': [123, 456]}
# See the following for more information:
# https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
        def flatten_dict(d: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
for k1, v1 in d.items():
if isinstance(v1, dict):
v1 = dict(flatten_dict(v1))
for k2, v2 in v1.items():
yield f"{k1}.{k2}", v2
else:
yield k1, v1

flattened = dict(flatten_dict(with_fixed_bools))
return flattened
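
    # Illustrative example of _fix_query_string (hypothetical input): booleans become
    # lowercase strings and nested keys are flattened into dotted names:
    #   _fix_query_string({'expand_tasks': True, 'filter_by': {'user_ids': [123, 456]}})
    #   == {'expand_tasks': 'true', 'filter_by.user_ids': [123, 456]}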

def do(self,
method: str,
url: str,
           query: Optional[dict] = None,
           headers: Optional[dict] = None,
           body: Optional[dict] = None,
           raw: bool = False,
           files=None,
           data=None,
           auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None,
           response_headers: Optional[List[str]] = None) -> Union[dict, list, BinaryIO]:
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)
response = retryable(self._perform)(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

        resp = {}
        for header in response_headers or []:
resp[header] = response.headers.get(Casing.to_header_case(header))
if raw:
resp["contents"] = _StreamingResponse(response)
return resp
        if not response.content:
return resp

json_response = response.json()
if json_response is None:
return resp

if isinstance(json_response, list):
return json_response

return {**resp, **json_response}
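
    # Usage sketch (hypothetical workspace URL; the surrounding SDK normally calls `do` for us):
    #   client = _BaseClient(retry_timeout_seconds=120)
    #   me = client.do('GET', 'https://my-workspace.cloud.databricks.com/api/2.0/preview/scim/v2/Me')
    # For raw downloads, pass raw=True and read from resp['contents'], a _StreamingResponse.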

@staticmethod
def _is_retryable(err: BaseException) -> Optional[str]:
        # This method is a Databricks-specific port of urllib3 retries
# (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
# and Databricks SDK for Go retries
# (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
from urllib3.exceptions import ProxyError
if isinstance(err, ProxyError):
err = err.original_error
if isinstance(err, requests.ConnectionError):
            # Corresponds to the `connection reset by peer` and `connection refused` errors from Go,
            # which are generally caused by temporary glitches in the networking stack. Endpoint
            # protection software, like ZScaler, can also cause them by dropping connections that
            # are not yet authenticated.
#
# return a simple string for debug log readability, as `raise TimeoutError(...) from err`
# will bubble up the original exception in case we reach max retries.
            return 'cannot connect'
if isinstance(err, requests.Timeout):
# corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
#
# return a simple string for debug log readability, as `raise TimeoutError(...) from err`
# will bubble up the original exception in case we reach max retries.
            return 'timeout'
if isinstance(err, DatabricksError):
message = str(err)
transient_error_string_matches = [
"com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
"does not have any associated worker environments", "There is no worker environment with id",
"Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
"Please try again later or try a faster operation.",
"RPC token bucket limit has been exceeded",
]
for substring in transient_error_string_matches:
if substring not in message:
continue
return f'matched {substring}'
return None
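
    # For instance, _is_retryable(requests.Timeout()) returns 'timeout' (retry), while
    # _is_retryable(ValueError()) returns None (fail immediately).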

def _perform(self,
method: str,
url: str,
                 query: Optional[dict] = None,
                 headers: Optional[dict] = None,
                 body: Optional[dict] = None,
                 raw: bool = False,
                 files=None,
                 data=None,
                 auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None):
response = self._session.request(method,
url,
params=self._fix_query_string(query),
json=body,
headers=headers,
files=files,
data=data,
auth=auth,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = self._error_parser.get_api_error(response)
if error is not None:
raise error from None
return response

def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
if not logger.isEnabledFor(logging.DEBUG):
return
logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())


class _StreamingResponse(BinaryIO):
_response: requests.Response
_buffer: bytes
_content: Union[Iterator[bytes], None]
_chunk_size: Union[int, None]
_closed: bool = False

    def fileno(self) -> int:
        raise NotImplementedError()

    def flush(self) -> None:
        # No-op: the underlying response stream has nothing to flush.
        pass

def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
self._response = response
self._buffer = b''
self._content = None
self._chunk_size = chunk_size

def _open(self) -> None:
if self._closed:
raise ValueError("I/O operation on closed file")
if not self._content:
self._content = self._response.iter_content(chunk_size=self._chunk_size)

def __enter__(self) -> BinaryIO:
self._open()
return self

def set_chunk_size(self, chunk_size: Union[int, None]) -> None:
self._chunk_size = chunk_size

def close(self) -> None:
self._response.close()
self._closed = True

def isatty(self) -> bool:
return False

def read(self, n: int = -1) -> bytes:
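        """Read up to n bytes from the streamed response body, or all remaining bytes if n < 0."""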
self._open()
read_everything = n < 0
remaining_bytes = n
res = b''
while remaining_bytes > 0 or read_everything:
if len(self._buffer) == 0:
try:
self._buffer = next(self._content)
except StopIteration:
break
bytes_available = len(self._buffer)
to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available)
res += self._buffer[:to_read]
self._buffer = self._buffer[to_read:]
remaining_bytes -= to_read
return res

def readable(self) -> bool:
return self._content is not None

def readline(self, __limit: int = ...) -> bytes:
raise NotImplementedError()

def readlines(self, __hint: int = ...) -> List[bytes]:
raise NotImplementedError()

def seek(self, __offset: int, __whence: int = ...) -> int:
raise NotImplementedError()

def seekable(self) -> bool:
return False

def tell(self) -> int:
raise NotImplementedError()

def truncate(self, __size: Union[int, None] = ...) -> int:
raise NotImplementedError()

def writable(self) -> bool:
return False

def write(self, s: Union[bytes, bytearray]) -> int:
raise NotImplementedError()

def writelines(self, lines: Iterable[bytes]) -> None:
raise NotImplementedError()

    def __next__(self) -> bytes:
        chunk = self.read(1)
        if not chunk:
            raise StopIteration
        return chunk

    def __iter__(self) -> Iterator[bytes]:
        self._open()
        return self._content

def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
traceback: Union[TracebackType, None]) -> None:
self._content = None
self._buffer = b''
self.close()
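

# Streaming usage sketch (hypothetical URL and handler): reading a raw response in chunks.
#   resp = client.do('GET', file_download_url, raw=True)
#   with resp['contents'] as f:
#       f.set_chunk_size(1024 * 1024)  # 1 MiB chunks from the underlying iterator
#       while True:
#           chunk = f.read(64 * 1024)
#           if not chunk:
#               break
#           handle(chunk)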