Skip to content

Commit

Permalink
refactor: Add type annotations and type variables to improve type che…
Browse files Browse the repository at this point in the history
…cking
  • Loading branch information
DavideGalilei committed Dec 3, 2024
1 parent 0857a45 commit 425ad9e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
16 changes: 12 additions & 4 deletions gpytranslate/gpytranslate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
get_base_headers,
)

T = TypeVar("T", str, List[str], Dict[Any, str], Mapping)
T = TypeVar("T", str, List[str], Dict[Any, str], Mapping[Any, str])
K = TypeVar("K")


class AsyncBufferedIOBase(Protocol):
async def write(self, data: bytes) -> int: ...
async def close(self) -> None: ...


class TranslatorOptions(TypedDict, total=False):
timeout: Optional[float]
verify: bool
cert: Optional[str]
trust_env: bool


class Translator(BaseTranslator):
Expand All @@ -27,9 +35,9 @@ def __init__(
proxies: Optional[Dict[str, str]] = None,
url: str = DEFAULT_TRANSLATION_ENDPOINT,
tts_url: str = DEFAULT_TTS_ENDPOINT,
headers: Optional[Union[dict, Callable[[], dict]]] = None,
**options,
):
headers: Optional[Union[Dict[str, str], Callable[[], Dict[str, str]]]] = None,
**options: Any,
) -> None:
self.url = url
self.tts_url = tts_url
self.proxies = proxies
Expand Down
13 changes: 8 additions & 5 deletions gpytranslate/types/base_translator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from .translated_object import TranslatedObject

K = TypeVar('K')
V = TypeVar('V')


class BaseTranslator:
headers: Union[Dict[str, str], Callable[[], Dict[str, str]]]
Expand All @@ -24,11 +27,11 @@ def parse(raw: Dict[str, Any], translated: bool = True) -> Union[TranslatedObjec

def check(
self,
text: Union[str, Mapping, List[str]],
text: Union[str, Mapping[K, str], List[str]],
raw: Union[Mapping[str, Any], List[Any]],
client: str,
dt: str,
) -> Union[TranslatedObject, Dict[str, TranslatedObject], List[TranslatedObject]]:
) -> Union[TranslatedObject, Dict[K, TranslatedObject], List[TranslatedObject]]:
"""Check and parse API response based on input type.
Args:
Expand Down Expand Up @@ -59,9 +62,9 @@ def parse_tts(
targetlang: str,
idx: int,
prev: str,
text: Union[str, List[str], Dict[Any, str], Mapping[Any, Any]],
text: Union[str, List[str], Dict[Any, str], Mapping[K, str]],
textlen: Optional[int],
extra: dict,
extra: Dict[str, Any],
) -> Dict[str, Union[str, int]]:
return {
k: v
Expand Down

0 comments on commit 425ad9e

Please sign in to comment.