Skip to content
This repository was archived by the owner on Mar 25, 2022. It is now read-only.

Commit 0e2d4ce

Browse files
authored
Merge pull request #22 from sh1ma/develop
Develop
2 parents b9513a2 + 404569d commit 0e2d4ce

File tree

3 files changed

+90
-30
lines changed

3 files changed

+90
-30
lines changed

src/apywrapper/_api.py

+37-13
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,62 @@
22
_api.py
33
"""
44

5-
from typing import Any, Callable
5+
from typing import Any, Callable, Optional
66

77
from httpx._types import HeaderTypes
88

99
from ._abc import Api
1010
from ._request import HttpClient, make_request, make_request_function
11-
from ._types import ApiFunc, Entity, ReturnEntity
11+
from ._types import ApiFunc, Entity, HookFunc, ReturnEntity, SerializeFunc
1212

1313

1414
class Apy(Api):
1515
"""
1616
Apy class
1717
"""
1818

19-
def __init__(self, host: str, headers: HeaderTypes) -> None:
19+
def __init__(
20+
self,
21+
host: str,
22+
headers: HeaderTypes,
23+
hook_func: Optional[HookFunc] = None,
24+
serialize_func: Optional[SerializeFunc] = None,
25+
) -> None:
26+
self.hook_func = hook_func if hook_func else None
27+
self.serialize_func = serialize_func if serialize_func else None
2028
self.http_client = HttpClient(base_url=host, headers=headers)
2129

2230
def get(self, path: str) -> Callable[[ApiFunc], ReturnEntity]:
23-
return make_request(path, self.http_client.get_request)
31+
return make_request(
32+
path, self.http_client.get_request, self.hook_func, self.serialize_func
33+
)
2434

2535
def post(self, path: str) -> Callable[[ApiFunc], ReturnEntity]:
26-
return make_request(path, self.http_client.post_request)
36+
return make_request(
37+
path, self.http_client.post_request, self.hook_func, self.serialize_func
38+
)
2739

2840
def put(self, path: str) -> Callable[[ApiFunc], ReturnEntity]:
29-
return make_request(path, self.http_client.put_request)
41+
return make_request(
42+
path, self.http_client.put_request, self.hook_func, self.serialize_func
43+
)
3044

3145
def delete(self, path: str) -> Callable[[ApiFunc], ReturnEntity]:
32-
return make_request(path, self.http_client.delete_request)
46+
return make_request(
47+
path, self.http_client.delete_request, self.hook_func, self.serialize_func
48+
)
3349

3450
def patch(self, path: str) -> Callable[[ApiFunc], ReturnEntity]:
35-
return make_request(path, self.http_client.patch_request)
51+
return make_request(
52+
path, self.http_client.patch_request, self.hook_func, self.serialize_func
53+
)
3654

3755

3856
def get(path: str) -> Callable[[ApiFunc], ReturnEntity]:
3957
def _get(func: ApiFunc) -> ReturnEntity:
4058
def wrapper(self: Apy, *args: Any, **kwargs: Any) -> Entity:
4159
return make_request_function(func, self, *args, **kwargs)(
42-
path, self.http_client.get_request
60+
path, self.http_client.get_request, self.hook_func, self.serialize_func
4361
)
4462

4563
return wrapper
@@ -51,7 +69,7 @@ def post(path: str) -> Callable[[ApiFunc], ReturnEntity]:
5169
def _post(func: ApiFunc) -> ReturnEntity:
5270
def wrapper(self: Apy, *args: Any, **kwargs: Any) -> Entity:
5371
return make_request_function(func, self, *args, **kwargs)(
54-
path, self.http_client.post_request
72+
path, self.http_client.post_request, self.hook_func, self.serialize_func
5573
)
5674

5775
return wrapper
@@ -63,7 +81,7 @@ def put(path: str) -> Callable[[ApiFunc], ReturnEntity]:
6381
def _put(func: ApiFunc) -> ReturnEntity:
6482
def wrapper(self: Apy, *args: Any, **kwargs: Any) -> Entity:
6583
return make_request_function(func, self, *args, **kwargs)(
66-
path, self.http_client.put_request
84+
path, self.http_client.put_request, self.hook_func, self.serialize_func
6785
)
6886

6987
return wrapper
@@ -75,7 +93,10 @@ def delete(path: str) -> Callable[[ApiFunc], ReturnEntity]:
7593
def _delete(func: ApiFunc) -> ReturnEntity:
7694
def wrapper(self: Apy, *args: Any, **kwargs: Any) -> Entity:
7795
return make_request_function(func, self, *args, **kwargs)(
78-
path, self.http_client.delete_request
96+
path,
97+
self.http_client.delete_request,
98+
self.hook_func,
99+
self.serialize_func,
79100
)
80101

81102
return wrapper
@@ -87,7 +108,10 @@ def patch(path: str) -> Callable[[ApiFunc], ReturnEntity]:
87108
def _patch(func: ApiFunc) -> ReturnEntity:
88109
def wrapper(self: Apy, *args: Any, **kwargs: Any) -> Entity:
89110
return make_request_function(func, self, *args, **kwargs)(
90-
path, self.http_client.patch_request
111+
path,
112+
self.http_client.patch_request,
113+
self.hook_func,
114+
self.serialize_func,
91115
)
92116

93117
return wrapper

src/apywrapper/_request.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@
33
"""
44

55

6-
from typing import Any, Callable, Dict, List, Optional, Type, Union
6+
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
77

88
from dacite import from_dict
99
from httpx import Client, Response
1010

1111
from ._path import Path
12-
from ._types import ApiFunc, Entity, EntityType, RequestFunc, ReturnEntity
12+
from ._types import (
13+
ApiFunc,
14+
Entity,
15+
EntityType,
16+
HookFunc,
17+
RequestFunc,
18+
ReturnEntity,
19+
SerializeFunc,
20+
)
1321
from ._utils import get_returntype_from_annotation
1422

1523

1624
def serialize(
1725
entity: Type[EntityType], json: Union[Dict, List]
1826
) -> Union[List[EntityType], EntityType]:
1927
if isinstance(json, list):
20-
return [from_dict(data_class=entity, data=d) for d in json]
28+
return [cast(EntityType, from_dict(data_class=entity, data=d)) for d in json]
2129

22-
return from_dict(data_class=entity, data=json)
30+
return cast(EntityType, from_dict(data_class=entity, data=json))
2331

2432

2533
def make_request_function(
2634
func: ApiFunc, *args: Any, **kwargs: Any
27-
) -> Callable[[str, RequestFunc], Entity]:
28-
def wrapper(path_str: str, request_func: RequestFunc) -> Entity:
35+
) -> Callable[
36+
[str, RequestFunc, Optional[HookFunc], Optional[SerializeFunc]], Union[Entity, Any]
37+
]:
38+
def wrapper(
39+
path_str: str,
40+
request_func: RequestFunc,
41+
hook_func: Optional[HookFunc] = None,
42+
serialize_func: Optional[SerializeFunc] = None,
43+
) -> Union[Entity, Any]:
2944
entity = get_returntype_from_annotation(func)
3045
params = func(*args, **kwargs)
3146
path = Path(path_str, params)
@@ -34,17 +49,28 @@ def wrapper(path_str: str, request_func: RequestFunc) -> Entity:
3449
entity is None or response.status_code == 204
3550
): # entity is None or response body is None
3651
return None
52+
if hook_func:
53+
return hook_func(entity, response)
54+
elif serialize_func:
55+
response.raise_for_status()
56+
return serialize(entity, response.json())
57+
response.raise_for_status()
3758
return serialize(entity, response.json())
3859

3960
return wrapper
4061

4162

4263
def make_request(
43-
path_str: str, request_func: RequestFunc
64+
path_str: str,
65+
request_func: RequestFunc,
66+
hook_func: Optional[HookFunc] = None,
67+
serialize_func: Optional[SerializeFunc] = None,
4468
) -> Callable[[ApiFunc], ReturnEntity]:
4569
def decorator(func: ApiFunc) -> ReturnEntity:
4670
def wrapper(*args: Any, **kwargs: Any) -> Entity:
47-
return make_request_function(func, *args, **kwargs)(path_str, request_func)
71+
return make_request_function(func, *args, **kwargs)(
72+
path_str, request_func, hook_func, serialize_func
73+
)
4874

4975
return wrapper
5076

@@ -66,29 +92,24 @@ class HttpClient(Client):
6692
def get_request(self, path: Path, params: Optional[Dict] = None) -> Response:
6793
real_params = pick_params(path, params)
6894
res = self.get(path, params=real_params)
69-
res.raise_for_status()
7095
return res
7196

7297
def post_request(self, path: Path, params: Optional[Dict] = None) -> Response:
7398
params = pick_params(path, params)
7499
res = self.post(path, json=params)
75-
res.raise_for_status()
76100
return res
77101

78102
def put_request(self, path: Path, params: Optional[Dict] = None) -> Response:
79103
params = pick_params(path, params)
80104
res = self.put(path, json=params)
81-
res.raise_for_status()
82105
return res
83106

84107
def delete_request(self, path: Path, params: Optional[Dict] = None) -> Response:
85108
params = pick_params(path, params)
86109
res = self.delete(path, params=params)
87-
res.raise_for_status()
88110
return res
89111

90112
def patch_request(self, path: Path, params: Optional[Dict] = None) -> Response:
91113
params = pick_params(path, params)
92114
res = self.patch(path, json=params)
93-
res.raise_for_status()
94115
return res

src/apywrapper/_types.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,36 @@
33
"""
44

55

6-
from typing import Callable, Dict, List, Optional, Protocol, Union
6+
from typing import Any, Callable, Dict, List, Optional, Protocol, Type, Union
77

88
from httpx import Response
99

1010

11-
class EntityType(Protocol):
11+
class DataclassEntityType(Protocol):
1212
# pylint: disable=too-few-public-methods
1313
"""
14-
Entity Object Type
14+
Entity Object Type on dataclass
1515
"""
1616

1717
__dataclass_fields__: Dict
1818

1919

20+
class PydanticEntityType(Protocol):
21+
"""
22+
Entity Object Type on Pydantic
23+
"""
24+
25+
__fields__: Dict
26+
27+
28+
EntityType = Union[DataclassEntityType, PydanticEntityType]
29+
30+
2031
Entity = Optional[Union[List[EntityType], EntityType]]
21-
ReturnEntity = Callable[..., Entity]
32+
ReturnEntity = Callable[..., Union[Entity, Any]]
2233
RequestFunc = Callable[..., Response]
2334
ApiFunc = Callable[..., Dict]
35+
SerializeFunc = Callable[
36+
[Type[EntityType], Union[Dict, List]], Union[List[EntityType], EntityType],
37+
]
38+
HookFunc = Callable[[Type[EntityType], Response], Any]

0 commit comments

Comments
 (0)