Skip to content

Commit

Permalink
refactor: Convert TranslatedObject to dataclass with improved type sa…
Browse files Browse the repository at this point in the history
…fety
  • Loading branch information
DavideGalilei committed Dec 3, 2024
1 parent 910cdca commit c1f2a54
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 44 deletions.
46 changes: 29 additions & 17 deletions gpytranslate/types/base_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,47 @@ class BaseTranslator:
@staticmethod
def parse(
raw: Union[dict, Mapping], translated: bool = True
) -> Union[TranslatedObject, Dict[str, Union[TranslatedObject, str, List[str]]]]:
x = {
"raw": TranslatedObject(raw),
"orig": " ".join(s["orig"] for s in raw["sentences"] if "orig" in s),
"text": " ".join(s["trans"] for s in raw["sentences"] if "trans" in s),
"orig_raw": [s["orig"] for s in raw["sentences"] if "orig" in s],
"text_raw": [s["trans"] for s in raw["sentences"] if "trans" in s],
"lang": raw["src"],
}
) -> Union[TranslatedObject, Dict[str, Any]]:
"""Parse raw API response into TranslatedObject.
Args:
raw: Raw response from translation API
translated: Whether to return TranslatedObject or dict
Returns:
Either TranslatedObject or raw dict based on translated parameter
"""
if translated:
return TranslatedObject(x)
return x
return TranslatedObject.from_raw_response(raw)
return raw

def check(
self,
text: Union[str, Mapping, Any],
text: Union[str, Mapping, List[str]],
raw: Union[Mapping, List],
client: str,
dt: str,
):
) -> Union[TranslatedObject, Dict[str, TranslatedObject], List[TranslatedObject]]:
"""Check and parse API response based on input type.
Args:
text: Original input text
raw: Raw API response
client: API client type
dt: Data type parameter
Returns:
Parsed translation result(s)
"""
if client != "gtx" or dt != "t":
return raw
return raw # type: ignore

if isinstance(text, str):
return self.parse(raw)
return self.parse(raw) # type: ignore
elif isinstance(text, Mapping):
return {k: self.parse(v) for k, v in raw.items()}
return {k: self.parse(v) for k, v in raw.items()} # type: ignore
else:
return [self.parse(elem) for elem in raw]
return [self.parse(elem) for elem in raw] # type: ignore

def get_headers(self) -> dict:
return self.headers() if callable(self.headers) else self.headers
Expand Down
56 changes: 29 additions & 27 deletions gpytranslate/types/translated_object.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
"""Translation result object implementation."""
import json
from typing import Any, Dict, List, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


class TranslatedObject(dict):
"""A dictionary subclass that holds translation results with attribute access."""
@dataclass
class TranslatedObject:
"""A dataclass that holds translation results."""
raw: Dict[str, Any]
orig: str
text: str
orig_raw: List[str]
text_raw: List[str]
lang: str

def __getattr__(self, attr: str) -> Union['TranslatedObject', List['TranslatedObject'], Any]:
"""Get attributes allowing dot notation access.
def __str__(self) -> str:
"""Get string representation with the translated text.
Args:
attr: The attribute name to access
Returns:
The attribute value, wrapped in TranslatedObject if it's a dict
str: The translated text
"""
if isinstance(self, list):
return [TranslatedObject(elem) for elem in self]

value = dict.get(self, attr)
if isinstance(value, dict):
return TranslatedObject(value)
return value
return self.text

def __str__(self) -> str:
"""Get string representation, truncating long values.
@classmethod
def from_raw_response(cls, raw: Dict[str, Any]) -> "TranslatedObject":
"""Create TranslatedObject from raw API response.
Args:
raw: Raw response dictionary from the translation API
Returns:
str: JSON formatted string with truncated values
TranslatedObject: Parsed translation result
"""
return json.dumps(
{k: v if len(str(v)) < 200 else "..." for k, v in self.items()},
indent=4
return cls(
raw=raw,
orig=" ".join(s["orig"] for s in raw["sentences"] if "orig" in s),
text=" ".join(s["trans"] for s in raw["sentences"] if "trans" in s),
orig_raw=[s["orig"] for s in raw["sentences"] if "orig" in s],
text_raw=[s["trans"] for s in raw["sentences"] if "trans" in s],
lang=raw["src"]
)

# Maintain dict-like attribute access
__setattr__ = dict.__setitem__ # type: ignore
__delattr__ = dict.__delitem__ # type: ignore

0 comments on commit c1f2a54

Please sign in to comment.