Skip to content

Commit

Permalink
fix(type): make some decorator utility functions type-safe and add so…
Browse files Browse the repository at this point in the history
…me type annotations (bentoml#1785)
  • Loading branch information
yetone authored Jul 28, 2021
1 parent e53abd5 commit 750b7e1
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 48 deletions.
72 changes: 44 additions & 28 deletions bentoml/_internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
import socket
import tarfile
from io import StringIO
from typing import Callable, Optional
from typing import (
Optional, TypeVar, Type, Union, overload, Dict, Iterator, Any, Tuple,
TYPE_CHECKING, Generic, Callable)
from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative

from google.protobuf.message import Message
from mypy.typeshed.stdlib.contextlib import _GeneratorContextManager

if TYPE_CHECKING:
from bentoml._internal.yatai_client import YataiClient

from ..utils.gcs import is_gcs_url
from ..utils.lazy_loader import LazyLoader
Expand Down Expand Up @@ -41,17 +47,21 @@


class _Missing(object):
def __repr__(self):
def __repr__(self) -> str:
return "no value"

def __reduce__(self):
def __reduce__(self) -> str:
return "_missing"


_missing = _Missing()


class cached_property(property):
T = TypeVar("T")
V = TypeVar("V")


class cached_property(property, Generic[T, V]):
"""A decorator that converts a function into a lazy property. The
function wrapped is called the first time to retrieve the result
and then that calculated result is used the next time you access
Expand All @@ -76,28 +86,32 @@ def foo(self):
manual invocation.
"""

def __init__(
self, func: Callable, name: str = None, doc: str = None
): # pylint:disable=super-init-not-called
def __init__(self, func: Callable[[T], V], name: Optional[str] = None, doc: Optional[str] = None): # pylint:disable=super-init-not-called
self.__name__ = name or func.__name__
self.__module__ = func.__module__
self.__doc__ = doc or func.__doc__
self.func = func

def __set__(self, obj, value):
def __set__(self, obj: T, value: V) -> None:
obj.__dict__[self.__name__] = value

def __get__(self, obj, type=None): # pylint:disable=redefined-builtin
@overload
def __get__(self, obj: None, type: Optional[Type[T]] = None) -> "cached_property": ...

@overload
def __get__(self, obj: T, type: Optional[Type[T]] = None) -> V: ...

def __get__(self, obj: Optional[T], type: Optional[Type[T]] = None) -> Union["cached_property", V]: # pylint:disable=redefined-builtin
if obj is None:
return self
value = obj.__dict__.get(self.__name__, _missing)
value: V = obj.__dict__.get(self.__name__, _missing)
if value is _missing:
value = self.func(obj)
obj.__dict__[self.__name__] = value
return value


class cached_contextmanager:
class cached_contextmanager(Generic[T]):
"""
Just like contextlib.contextmanager, but will cache the yield value for the same
arguments. When one instance of the contextmanager exits, the cache value will
Expand All @@ -113,20 +127,21 @@ def start_docker_container_from_image(docker_image, timeout=60):
container.stop()
"""

def __init__(self, cache_key_template=None):
def __init__(self, cache_key_template: Optional[str] = None) -> None:
self._cache_key_template = cache_key_template
self._cache = {}
self._cache: Dict[Union[str, Tuple], T] = {}

def __call__(self, func):
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
def __call__(self, func: Callable[..., Iterator[T]]) -> Callable[..., _GeneratorContextManager[T]]:
func_m = contextlib.contextmanager(func)

@contextlib.contextmanager
@functools.wraps(func)
def _func(*args, **kwargs):
def _func(*args: Any, **kwargs: Any) -> Iterator[T]:
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
if self._cache_key_template:
cache_key = self._cache_key_template.format(**bound_args.arguments)
cache_key: Union[str, Tuple] = self._cache_key_template.format(**bound_args.arguments)
else:
cache_key = tuple(bound_args.arguments.values())
if cache_key in self._cache:
Expand All @@ -141,7 +156,7 @@ def _func(*args, **kwargs):


@contextlib.contextmanager
def reserve_free_port(host="localhost"):
def reserve_free_port(host: str = "localhost") -> Iterator[int]:
"""
detect free port and reserve until exit the context
"""
Expand All @@ -152,13 +167,13 @@ def reserve_free_port(host="localhost"):
sock.close()


def get_free_port(host="localhost"):
def get_free_port(host: str = "localhost") -> int:
"""
detect free port and reserve until exit the context
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((host, 0))
port = sock.getsockname()[1]
port: int = sock.getsockname()[1]
sock.close()
return port

Expand All @@ -170,7 +185,7 @@ def is_url(url: str) -> bool:
return False


def dump_to_yaml_str(yaml_dict):
def dump_to_yaml_str(yaml_dict: Dict) -> str:
from ..utils.ruamel_yaml import YAML

yaml = YAML()
Expand All @@ -186,7 +201,7 @@ def pb_to_yaml(message: Message) -> str:
return dump_to_yaml_str(message_dict)


def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:
def ProtoMessageToDict(protobuf_msg: Message, **kwargs: Any) -> object:
from google.protobuf.json_format import MessageToDict

if "preserving_proto_field_name" not in kwargs:
Expand All @@ -196,7 +211,7 @@ def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:


# This function assume the status is not status.OK
def status_pb_to_error_code_and_message(pb_status) -> (int, str):
def status_pb_to_error_code_and_message(pb_status) -> Tuple[int, str]:
from ..yatai_client.proto import status_pb2

assert pb_status.status_code != status_pb2.Status.OK
Expand All @@ -205,14 +220,15 @@ def status_pb_to_error_code_and_message(pb_status) -> (int, str):
return error_code, error_message


class catch_exceptions(object):
def __init__(self, exceptions, fallback=None):
class catch_exceptions(object, Generic[T]):
def __init__(self, exceptions: Union[Type[BaseException], Tuple[Type[BaseException], ...]], fallback: Optional[T] = None) -> None:
self.exceptions = exceptions
self.fallback = fallback

def __call__(self, func):
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
def __call__(self, func: Callable[..., T]) -> Callable[..., Optional[T]]:
@functools.wraps(func)
def _(*args, **kwargs):
def _(*args: Any, **kwargs: Any) -> Optional[T]:
try:
return func(*args, **kwargs)
except self.exceptions:
Expand Down Expand Up @@ -253,7 +269,7 @@ def resolve_bundle_path(
)


def get_default_yatai_client():
def get_default_yatai_client() -> YataiClient:
from bentoml._internal.yatai_client import YataiClient

return YataiClient()
Expand All @@ -271,7 +287,7 @@ def resolve_bento_bundle_uri(bento_pb):

def archive_directory_to_tar(
source_dir: str, tarfile_dir: str, tarfile_name: str
) -> (str, str):
) -> Tuple[str, str]:
file_name = f"{tarfile_name}.tar"
tarfile_path = os.path.join(tarfile_dir, file_name)
with tarfile.open(tarfile_path, mode="w:gz") as tar:
Expand Down
16 changes: 8 additions & 8 deletions bentoml/_internal/utils/csv.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# CSV utils following https://tools.ietf.org/html/rfc4180
from typing import Iterable, Iterator
from typing import Iterable, Iterator, Union


def csv_splitlines(string) -> Iterator[str]:
def csv_splitlines(string: str) -> Iterator[str]:
if '"' in string:

def _iter_line(line):
def _iter_line(line: str) -> Iterator[str]:
quoted = False
last_cur = 0
for i, c in enumerate(line):
Expand All @@ -25,11 +25,11 @@ def _iter_line(line):
return iter(string.splitlines())


def csv_split(string, delimiter) -> Iterator[str]:
def csv_split(string: str, delimiter: str) -> Iterator[str]:
if '"' in string:
dlen = len(delimiter)

def _iter_line(line):
def _iter_line(line: str) -> Iterator[str]:
quoted = False
last_cur = 0
for i, c in enumerate(line):
Expand All @@ -45,19 +45,19 @@ def _iter_line(line):
return iter(string.split(delimiter))


def csv_row(tds: Iterable):
def csv_row(tds: Iterable) -> str:
return ",".join(csv_quote(td) for td in tds)


def csv_unquote(string):
def csv_unquote(string: str) -> str:
if '"' in string:
string = string.strip()
assert string[0] == '"' and string[-1] == '"'
return string[1:-1].replace('""', '"')
return string


def csv_quote(td):
def csv_quote(td: Union[int, str]) -> str:
"""
>>> csv_quote(1)
'1'
Expand Down
4 changes: 2 additions & 2 deletions bentoml/_internal/utils/dataframe_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import itertools
import json
from typing import Iterable, Iterator, Mapping
from typing import Iterable, Iterator, Mapping, Any, Union, Set

from bentoml.exceptions import BadInput

Expand All @@ -24,7 +24,7 @@ def check_dataframe_column_contains(required_column_names, df):


@catch_exceptions(Exception, fallback=None)
def guess_orient(table, strict=False):
def guess_orient(table: Any, strict: bool = False) -> Union[None, str, Set[str]]:
if isinstance(table, list):
if not table:
if strict:
Expand Down
12 changes: 8 additions & 4 deletions bentoml/_internal/yatai_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#
# logger = logging.getLogger(__name__)

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from bentoml._internal.yatai_client.proto.yatai_service_pb2_grpc import YataiStub

import logging

Expand All @@ -33,11 +37,11 @@ def __init__(self, yatai_server_name: str):
self.deploy_api_client = None

@cached_property
def bundles(self):
def bundles(self) -> BentoRepositoryAPIClient:
return BentoRepositoryAPIClient(self._yatai_service)

@cached_property
def deployment(self):
def deployment(self) -> DeploymentAPIClient:
return DeploymentAPIClient(self._yatai_service)

# def __init__(self, yatai_service: Optional["YataiStub"] = None):
Expand All @@ -50,7 +54,7 @@ def deployment(self):
# return BentoRepositoryAPIClient(self.yatai_service)


def get_yatai_client(yatai_url: str = None) -> "YataiClient":
def get_yatai_client(yatai_url: Optional[str] = None) -> YataiClient:
"""
Args:
yatai_url (`str`):
Expand Down Expand Up @@ -80,7 +84,7 @@ def get_yatai_service(
tls_root_ca_cert: str,
tls_client_key: str,
tls_client_cert: str,
):
) -> YataiStub:
import certifi
import grpc

Expand Down
5 changes: 4 additions & 1 deletion bentoml/_internal/yatai_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
def parse_grpc_url(url):
from typing import Tuple, Optional


def parse_grpc_url(url: str) -> Tuple[Optional[str], str]:
"""
>>> parse_grpc_url("grpcs://yatai.com:43/query")
('grpcs', 'yatai.com:43/query')
Expand Down
6 changes: 3 additions & 3 deletions bentoml/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

import cloudpickle

from ._internal.models.base import (
from bentoml._internal.models.base import (
H5_EXTENSION,
HDF5_EXTENSION,
JSON_EXTENSION,
MODEL_NAMESPACE,
PICKLE_EXTENSION,
Model,
)
from ._internal.types import MetadataType, PathType
from .exceptions import MissingDependencyException
from bentoml._internal.types import MetadataType, PathType
from bentoml.exceptions import MissingDependencyException

# fmt: off
try:
Expand Down
4 changes: 2 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[mypy]
show_error_codes = True
disable_error_code = attr-defined
exclude = "|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
exclude = "|venv|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
ignore_missing_imports = True

# mypy --strict --allow-any-generics --allow-subclassing-any --no-check-untyped-defs --allow-untyped-call
Expand All @@ -22,4 +22,4 @@ ignore_errors = True
ignore_errors = True

[mypy-*.exceptions.*]
ignore_errors = True
ignore_errors = True

0 comments on commit 750b7e1

Please sign in to comment.