From 425ad9eb581655cd1902c630dac6901a92b61933 Mon Sep 17 00:00:00 2001 From: "Davide Galilei (aider)" Date: Tue, 3 Dec 2024 20:03:34 +0100 Subject: [PATCH] refactor: Add type annotations and type variables to improve type checking --- gpytranslate/gpytranslate.py | 16 ++++++++++++---- gpytranslate/types/base_translator.py | 13 ++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/gpytranslate/gpytranslate.py b/gpytranslate/gpytranslate.py index 3f0437e..e85cc62 100644 --- a/gpytranslate/gpytranslate.py +++ b/gpytranslate/gpytranslate.py @@ -13,12 +13,20 @@ get_base_headers, ) -T = TypeVar("T", str, List[str], Dict[Any, str], Mapping) +T = TypeVar("T", str, List[str], Dict[Any, str], Mapping[Any, str]) +K = TypeVar("K") class AsyncBufferedIOBase(Protocol): async def write(self, data: bytes) -> int: ... async def close(self) -> None: ... + + +class TranslatorOptions(TypedDict, total=False): + timeout: Optional[float] + verify: bool + cert: Optional[str] + trust_env: bool class Translator(BaseTranslator): @@ -27,9 +35,9 @@ def __init__( proxies: Optional[Dict[str, str]] = None, url: str = DEFAULT_TRANSLATION_ENDPOINT, tts_url: str = DEFAULT_TTS_ENDPOINT, - headers: Optional[Union[dict, Callable[[], dict]]] = None, - **options, - ): + headers: Optional[Union[Dict[str, str], Callable[[], Dict[str, str]]]] = None, + **options: Any, + ) -> None: self.url = url self.tts_url = tts_url self.proxies = proxies diff --git a/gpytranslate/types/base_translator.py b/gpytranslate/types/base_translator.py index ad9d669..f46e006 100644 --- a/gpytranslate/types/base_translator.py +++ b/gpytranslate/types/base_translator.py @@ -1,8 +1,11 @@ from collections.abc import Mapping -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from .translated_object import TranslatedObject +K = TypeVar('K') +V = TypeVar('V') + class BaseTranslator: headers: Union[Dict[str, str], Callable[[], Dict[str, str]]] @@ -24,11 +27,11 @@ def parse(raw: Dict[str, Any], translated: bool = True) -> Union[TranslatedObjec def check( self, - text: Union[str, Mapping, List[str]], + text: Union[str, Mapping[K, str], List[str]], raw: Union[Mapping[str, Any], List[Any]], client: str, dt: str, - ) -> Union[TranslatedObject, Dict[str, TranslatedObject], List[TranslatedObject]]: + ) -> Union[TranslatedObject, Dict[K, TranslatedObject], List[TranslatedObject]]: """Check and parse API response based on input type. Args: @@ -59,9 +62,9 @@ def parse_tts( targetlang: str, idx: int, prev: str, - text: Union[str, List[str], Dict[Any, str], Mapping[Any, Any]], + text: Union[str, List[str], Dict[Any, str], Mapping[K, str]], textlen: Optional[int], - extra: dict, + extra: Dict[str, Any], ) -> Dict[str, Union[str, int]]: return { k: v