Skip to content

Commit a5edb4c

Browse files
authored
✨ add overloads for rest api
1 parent 4532467 commit a5edb4c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+13583
-3120
lines changed

codegen/parser/endpoints/endpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def get_imports(self) -> Set[str]:
5959
if self.request_body:
6060
imports.update(self.request_body.get_param_imports())
6161
imports.update(self.request_body.get_using_imports())
62+
if self.request_body.allowed_models:
63+
imports.add("from typing import Union")
64+
imports.add("from githubkit.utils import UNSET, Unset")
6265
if self.success_response:
6366
imports.update(self.success_response.get_using_imports())
6467
for resp in self.error_responses.values():

codegen/parser/endpoints/request_body.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ...source import Source
77
from ..utils import concat_snake_name
8-
from ..schemas import Property, SchemaData, ModelSchema, parse_schema
8+
from ..schemas import Property, SchemaData, ModelSchema, UnionSchema, parse_schema
99

1010

1111
class RequestBodyData(BaseModel):
@@ -14,26 +14,40 @@ class RequestBodyData(BaseModel):
1414
required: bool = False
1515

1616
@property
17-
def is_model(self) -> bool:
18-
return isinstance(self.body_schema, ModelSchema)
17+
def allowed_models(self) -> List[ModelSchema]:
18+
if isinstance(self.body_schema, ModelSchema):
19+
return [self.body_schema]
20+
elif isinstance(self.body_schema, UnionSchema):
21+
return [
22+
schema
23+
for schema in self.body_schema.schemas
24+
if isinstance(schema, ModelSchema)
25+
]
26+
return []
1927

20-
def get_params_defination(self) -> List[str]:
21-
if not isinstance(self.body_schema, ModelSchema):
22-
prop = Property(
23-
name="body",
24-
prop_name="body",
25-
required=self.required,
26-
schema_data=self.body_schema,
27-
)
28-
return [prop.get_param_defination()]
28+
def get_raw_definition(self) -> str:
29+
prop = Property(
30+
name="data",
31+
prop_name="data",
32+
required=self.required,
33+
schema_data=self.body_schema,
34+
)
35+
return prop.get_param_defination()
2936

30-
return [prop.get_param_defination() for prop in self.body_schema.properties]
37+
def get_endpoint_definition(self) -> str:
38+
prop = Property(
39+
name="data",
40+
prop_name="data",
41+
required=not bool(self.allowed_models),
42+
schema_data=self.body_schema,
43+
)
44+
return prop.get_param_defination()
3145

3246
def get_param_imports(self) -> Set[str]:
3347
imports = set()
3448
imports.update(self.body_schema.get_param_imports())
35-
if isinstance(self.body_schema, ModelSchema):
36-
for prop in self.body_schema.properties:
49+
for model in self.allowed_models:
50+
for prop in model.properties:
3751
imports.update(prop.get_param_imports())
3852
return imports
3953

codegen/templates/client/_param.py.jinja

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,39 @@
2222
{% endfor %}
2323
{% endmacro %}
2424

25-
{% macro body_params(endpoint) %}
26-
{% for prop in endpoint.request_body.get_params_defination() %}
27-
{{ prop }},
25+
{% macro body_params(model) %}
26+
{% for prop in model.properties %}
27+
{{ prop.get_param_defination() }},
2828
{% endfor %}
2929
{% endmacro %}
3030

31-
{% macro endpoint_params(endpoint) %}
31+
{% macro endpoint_raw_params(endpoint) %}
32+
{{ path_params(endpoint) }}
33+
{{ query_params(endpoint) }}
34+
{{ header_params(endpoint) }}
35+
{{ cookie_params(endpoint) }}
36+
*,
37+
{{ endpoint.request_body.get_raw_definition() }}
38+
{% endmacro %}
39+
40+
{% macro endpoint_model_params(endpoint, model) %}
41+
{{ path_params(endpoint) }}
42+
{{ query_params(endpoint) }}
43+
{{ header_params(endpoint) }}
44+
{{ cookie_params(endpoint) }}
45+
*,
46+
data: Unset = UNSET,
47+
{{ body_params(model) }}
48+
{% endmacro %}
49+
50+
{% macro endpoint_params(endpoint, model) %}
3251
{{ path_params(endpoint) }}
3352
{{ query_params(endpoint) }}
3453
{{ header_params(endpoint) }}
3554
{{ cookie_params(endpoint) }}
3655
{%- if endpoint.request_body %}
3756
*,
38-
{{ body_params(endpoint) }}
57+
{{ endpoint.request_body.get_endpoint_definition() }},
58+
**kwargs
3959
{%- endif %}
4060
{% endmacro %}

codegen/templates/client/_request.py.jinja

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
{% from "client/_response.py.jinja" import build_response_model, build_error_models %}
22

3+
{% set TYPE_MAPPING = {"json": "json", "form": "data", "file": "files", "raw": "content"} %}
4+
35
{% macro build_path(endpoint) %}
46
{% if endpoint.path_params %}
57
url = f"{{ endpoint.path }}"
@@ -32,28 +34,17 @@ cookies = {
3234
}
3335
{% endmacro %}
3436

35-
{% macro _get_body_name(request_body) %}
36-
{% if request_body.is_model %}
37-
{{ request_body.body_schema.class_name }}(**{
38-
{% for prop in request_body.body_schema.properties %}
39-
"{{ prop.name }}": {{ prop.prop_name }},
40-
{% endfor %}
41-
}).dict(by_alias=True)
42-
{% else %}
43-
body
44-
{% endif %}
45-
{% endmacro %}
46-
4737
{% macro build_body(request_body) %}
48-
{% if request_body.type == "json" %}
49-
json = {{ _get_body_name(request_body) }}
50-
{% elif request_body.type == "form" %}
51-
data = {{ _get_body_name(request_body) }}
52-
{% elif request_body.type == "file" %}
53-
files = {{ _get_body_name(request_body) }}
54-
{% elif request_body.type == "raw" %}
55-
content = {{ _get_body_name(request_body) }}
56-
{% endif %}
38+
{% set name = TYPE_MAPPING[request_body.type] %}
39+
if not kwargs:
40+
kwargs = UNSET
41+
42+
{{ name }} = kwargs if data is UNSET else data
43+
{{ name }} = parse_obj_as(
44+
{{ request_body.body_schema.get_type_string() }},
45+
{{ name }}
46+
)
47+
{{ name }} = {{ name }}.dict(by_alias=True) if isinstance({{ name }}, BaseModel) else {{ name }}
5748
{% endmacro %}
5849

5950
{% macro build_request(endpoint) %}
@@ -73,15 +64,8 @@ url,
7364
params=exclude_unset(params),
7465
{% endif %}
7566
{% if endpoint.request_body %}
76-
{% if endpoint.request_body.type == "raw" %}
77-
content=exclude_unset(content),
78-
{% elif endpoint.request_body.type == "form" %}
79-
data=exclude_unset(data),
80-
{% elif endpoint.request_body.type == "file" %}
81-
files=exclude_unset(files),
82-
{% elif endpoint.request_body.type == "json" %}
83-
json=exclude_unset(json),
84-
{% endif %}
67+
{% set name = TYPE_MAPPING[endpoint.request_body.type] %}
68+
{{ name }}=exclude_unset({{ name }}),
8569
{% endif %}
8670
{% if endpoint.header_params %}
8771
headers=exclude_unset(headers),

codegen/templates/client/client.py.jinja

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This file is auto generated by github rest api discription.
44
See https://github.com/github/rest-api-description for more information.
55
"""
66

7-
{% from "client/_param.py.jinja" import endpoint_params %}
7+
{% from "client/_param.py.jinja" import endpoint_params, endpoint_raw_params, endpoint_model_params %}
88
{% from "client/_response.py.jinja" import build_response_type %}
99
{% from "client/_request.py.jinja" import build_request, build_request_params %}
1010

@@ -14,7 +14,9 @@ See https://github.com/github/rest-api-description for more information.
1414
{% endfor %}
1515
{% endfor %}
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, overload
18+
19+
from pydantic import BaseModel, parse_obj_as
1820

1921
from githubkit.utils import exclude_unset
2022

@@ -27,6 +29,26 @@ class {{ pascal_case(tag) }}Client:
2729
self._github = github
2830

2931
{% for endpoint in endpoints %}
32+
{% if endpoint.request_body and endpoint.request_body.allowed_models %}
33+
{# generate raw data overload #}
34+
@overload
35+
def {{ endpoint.name }}(
36+
self,
37+
{{ endpoint_raw_params(endpoint) | indent(8) }}
38+
) -> "{{ build_response_type(endpoint.success_response) }}":
39+
...
40+
41+
{# generate model data overload #}
42+
{% for model in endpoint.request_body.allowed_models %}
43+
@overload
44+
def {{ endpoint.name }}(
45+
self,
46+
{{ endpoint_model_params(endpoint, model) | indent(8) }}
47+
) -> "{{ build_response_type(endpoint.success_response) }}":
48+
...
49+
50+
{% endfor %}
51+
{% endif %}
3052
def {{ endpoint.name }}(
3153
self,
3254
{{ endpoint_params(endpoint) | indent(8) }}
@@ -36,6 +58,26 @@ class {{ pascal_case(tag) }}Client:
3658
{{ build_request_params(endpoint) | indent(12) }}
3759
)
3860

61+
{% if endpoint.request_body and endpoint.request_body.allowed_models %}
62+
{# generate raw data overload #}
63+
@overload
64+
async def async_{{ endpoint.name }}(
65+
self,
66+
{{ endpoint_raw_params(endpoint) | indent(8) }}
67+
) -> "{{ build_response_type(endpoint.success_response) }}":
68+
...
69+
70+
{# generate model data overload #}
71+
{% for model in endpoint.request_body.allowed_models %}
72+
@overload
73+
async def async_{{ endpoint.name }}(
74+
self,
75+
{{ endpoint_model_params(endpoint, model) | indent(8) }}
76+
) -> "{{ build_response_type(endpoint.success_response) }}":
77+
...
78+
79+
{% endfor %}
80+
{% endif %}
3981
async def async_{{ endpoint.name }}(
4082
self,
4183
{{ endpoint_params(endpoint) | indent(8) }}

codegen/templates/models/models.py.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ from __future__ import annotations
1515
{% endfor %}
1616

1717

18-
class GitHubModel(BaseModel, arbitrary_types_allowed=True):
18+
class GitHubModel(BaseModel, allow_population_by_field_name=True):
1919
...
2020

2121
{# model #}

codegen/templates/namespace/namespace.py.jinja

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ if TYPE_CHECKING:
1717

1818

1919
class RestNamespace:
20-
__slots__ = ("_github",)
21-
2220
def __init__(self, github: "GitHubCore"):
2321
self._github = github
2422

githubkit/rest/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@
5050

5151

5252
class RestNamespace:
53-
__slots__ = ("_github",)
54-
5553
def __init__(self, github: "GitHubCore"):
5654
self._github = github
5755

0 commit comments

Comments
 (0)