Skip to content

Commit 4532467

Browse files
authored
🚸 improve parser and code generator
1 parent 15903b8 commit 4532467

File tree

17 files changed

+2439
-2465
lines changed

17 files changed

+2439
-2465
lines changed

codegen/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ def build_templates(data: GeneratorData, config: Config):
4040
models_template = env.get_template("models/models.py.jinja")
4141
models_path = Path(config.models_output)
4242
models_path.parent.mkdir(parents=True, exist_ok=True)
43-
models_path.write_text(models_template.render(data=data))
43+
models_path.write_text(models_template.render(models=data.models))
4444
logger.info("Successfully built models!")
4545

4646
# build types
4747
logger.info("Building types...")
48-
types_template = env.get_template("types/types.py.jinja")
48+
types_template = env.get_template("models/types.py.jinja")
4949
types_path = Path(config.types_output)
5050
types_path.parent.mkdir(parents=True, exist_ok=True)
51-
types_path.write_text(types_template.render(data=data))
51+
types_path.write_text(types_template.render(models=data.models))
5252
logger.info("Successfully built types!")
5353

5454
# build endpoints

codegen/parser/data.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from itertools import chain
21
from collections import defaultdict
3-
from typing import Set, Dict, List, Optional
2+
from typing import Dict, List, Optional
43

54
from pydantic import BaseModel
65

76
from .endpoints import EndpointData
87
from .schemas import SchemaData, ModelSchema
9-
from .utils import snake_case, fix_reserved_words
108

119

1210
class GeneratorData(BaseModel):
@@ -28,6 +26,3 @@ def endpoints_by_tag(self) -> Dict[str, List[EndpointData]]:
2826
@property
2927
def models(self) -> List[ModelSchema]:
3028
return [schema for schema in self.schemas if isinstance(schema, ModelSchema)]
31-
32-
def get_imports(self) -> Set[str]:
33-
return set(chain.from_iterable(schema.get_imports() for schema in self.schemas))

codegen/parser/endpoints/request_body.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@ def get_params_defination(self) -> List[str]:
3131

3232
def get_param_imports(self) -> Set[str]:
3333
imports = set()
34+
imports.update(self.body_schema.get_param_imports())
3435
if isinstance(self.body_schema, ModelSchema):
35-
imports.update(self.body_schema.get_param_imports())
3636
for prop in self.body_schema.properties:
3737
imports.update(prop.get_param_imports())
38-
else:
39-
imports.update(self.body_schema.get_param_imports())
4038
return imports
4139

4240
def get_using_imports(self) -> Set[str]:

codegen/parser/schemas/schema.py

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,23 @@ def get_param_type_string(self) -> str:
2020
"""Get type string used by client codegen"""
2121
return self.get_type_string()
2222

23-
def get_imports(self) -> Set[str]:
24-
"""Get schema needed imports for creating the schema"""
23+
def get_model_imports(self) -> Set[str]:
24+
"""Get schema needed imports for model codegen"""
2525
return set()
2626

27+
def get_type_imports(self) -> Set[str]:
28+
"""Get schema needed imports for types codegen"""
29+
return self.get_model_imports()
30+
2731
def get_param_imports(self) -> Set[str]:
28-
"""Get schema needed imports for typing params"""
29-
return self.get_imports()
32+
"""Get schema needed imports for client param codegen"""
33+
return self.get_model_imports()
3034

3135
def get_using_imports(self) -> Set[str]:
32-
"""Get schema needed imports for using"""
33-
return self.get_imports()
36+
"""Get schema needed imports for client request codegen"""
37+
return self.get_model_imports()
3438

35-
def get_default_args(self) -> Dict[str, str]:
39+
def _get_default_args(self) -> Dict[str, str]:
3640
"""Get pydantic field info args"""
3741
default = self.default
3842
args = {}
@@ -59,18 +63,26 @@ def get_type_string(self) -> str:
5963
return type_string if self.required else f"Union[Unset, {type_string}]"
6064

6165
def get_param_type_string(self) -> str:
66+
"""Get type string used by client codegen"""
6267
type_string = self.schema_data.get_param_type_string()
6368
return type_string if self.required else f"Union[Unset, {type_string}]"
6469

65-
def get_imports(self) -> Set[str]:
66-
"""Get schema needed imports for creating the schema"""
67-
imports = self.schema_data.get_imports()
70+
def get_model_imports(self) -> Set[str]:
71+
"""Get schema needed imports for model codegen"""
72+
imports = self.schema_data.get_model_imports()
6873
imports.add("from pydantic import Field")
6974
if not self.required:
7075
imports.add("from typing import Union")
7176
imports.add("from githubkit.utils import UNSET, Unset")
7277
return imports
7378

79+
def get_type_imports(self) -> Set[str]:
80+
"""Get schema needed imports for type codegen"""
81+
imports = self.schema_data.get_type_imports()
82+
if not self.required:
83+
imports.add("from typing_extensions import NotRequired")
84+
return imports
85+
7486
def get_param_imports(self) -> Set[str]:
7587
"""Get schema needed imports for typing params"""
7688
imports = self.schema_data.get_param_imports()
@@ -79,8 +91,8 @@ def get_param_imports(self) -> Set[str]:
7991
imports.add("from githubkit.utils import UNSET, Unset")
8092
return imports
8193

82-
def get_default_string(self) -> str:
83-
args = self.schema_data.get_default_args()
94+
def _get_default_string(self) -> str:
95+
args = self.schema_data._get_default_args()
8496
if "default" not in args and "default_factory" not in args:
8597
args["default"] = "..." if self.required else "UNSET"
8698
if self.prop_name != self.name:
@@ -103,12 +115,12 @@ def get_param_defination(self) -> str:
103115
def get_model_defination(self) -> str:
104116
"""Get defination used by model codegen"""
105117
type_ = self.get_type_string()
106-
default = self.get_default_string()
118+
default = self._get_default_string()
107119
return f"{self.prop_name}: {type_} = {default}"
108120

109121
def get_type_defination(self) -> str:
110122
"""Get defination used by types codegen"""
111-
type_ = self.get_param_type_string()
123+
type_ = self.schema_data.get_param_type_string()
112124
return (
113125
f"{self.prop_name}: {type_ if self.required else f'NotRequired[{type_}]'}"
114126
)
@@ -117,8 +129,8 @@ def get_type_defination(self) -> str:
117129
class AnySchema(SchemaData):
118130
_type_string: ClassVar[str] = "Any"
119131

120-
def get_imports(self) -> Set[str]:
121-
imports = super().get_imports()
132+
def get_model_imports(self) -> Set[str]:
133+
imports = super().get_model_imports()
122134
imports.add("from typing import Any")
123135
return imports
124136

@@ -140,8 +152,8 @@ class IntSchema(SchemaData):
140152

141153
_type_string: ClassVar[str] = "int"
142154

143-
def get_default_args(self) -> Dict[str, str]:
144-
args = super().get_default_args()
155+
def _get_default_args(self) -> Dict[str, str]:
156+
args = super()._get_default_args()
145157
if self.multiple_of is not None:
146158
args["multiple_of"] = repr(self.multiple_of)
147159
if self.maximum is not None:
@@ -164,8 +176,8 @@ class FloatSchema(SchemaData):
164176

165177
_type_string: ClassVar[str] = "float"
166178

167-
def get_default_args(self) -> Dict[str, str]:
168-
args = super().get_default_args()
179+
def _get_default_args(self) -> Dict[str, str]:
180+
args = super()._get_default_args()
169181
if self.multiple_of is not None:
170182
args["multiple_of"] = str(self.multiple_of)
171183
if self.maximum is not None:
@@ -186,8 +198,8 @@ class StringSchema(SchemaData):
186198

187199
_type_string: ClassVar[str] = "str"
188200

189-
def get_default_args(self) -> Dict[str, str]:
190-
args = super().get_default_args()
201+
def _get_default_args(self) -> Dict[str, str]:
202+
args = super()._get_default_args()
191203
if self.min_length is not None:
192204
args["min_length"] = str(self.min_length)
193205
if self.max_length is not None:
@@ -200,26 +212,26 @@ def get_default_args(self) -> Dict[str, str]:
200212
class DateTimeSchema(SchemaData):
201213
_type_string: ClassVar[str] = "datetime"
202214

203-
def get_imports(self) -> Set[str]:
204-
imports = super().get_imports()
215+
def get_model_imports(self) -> Set[str]:
216+
imports = super().get_model_imports()
205217
imports.add("from datetime import datetime")
206218
return imports
207219

208220

209221
class DateSchema(SchemaData):
210222
_type_string: ClassVar[str] = "date"
211223

212-
def get_imports(self) -> Set[str]:
213-
imports = super().get_imports()
224+
def get_model_imports(self) -> Set[str]:
225+
imports = super().get_model_imports()
214226
imports.add("from datetime import date")
215227
return imports
216228

217229

218230
class FileSchema(SchemaData):
219231
_type_string: ClassVar[str] = "FileTypes"
220232

221-
def get_imports(self) -> Set[str]:
222-
imports = super().get_imports()
233+
def get_model_imports(self) -> Set[str]:
234+
imports = super().get_model_imports()
223235
imports.add("from githubkit.typing import FileTypes")
224236
return imports
225237

@@ -236,10 +248,15 @@ def get_type_string(self) -> str:
236248
def get_param_type_string(self) -> str:
237249
return f"List[{self.item_schema.get_param_type_string()}]"
238250

239-
def get_imports(self) -> Set[str]:
240-
imports = super().get_imports()
251+
def get_model_imports(self) -> Set[str]:
252+
imports = super().get_model_imports()
241253
imports.add("from typing import List")
242-
imports.update(self.item_schema.get_imports())
254+
imports.update(self.item_schema.get_model_imports())
255+
return imports
256+
257+
def get_type_imports(self) -> Set[str]:
258+
imports = {"from typing import List"}
259+
imports.update(self.item_schema.get_type_imports())
243260
return imports
244261

245262
def get_param_imports(self) -> Set[str]:
@@ -252,8 +269,8 @@ def get_using_imports(self) -> Set[str]:
252269
imports.update(self.item_schema.get_using_imports())
253270
return imports
254271

255-
def get_default_args(self) -> Dict[str, str]:
256-
args = super().get_default_args()
272+
def _get_default_args(self) -> Dict[str, str]:
273+
args = super()._get_default_args()
257274
# FIXME: remove list constraints due to forwardref not supported
258275
# See https://github.com/samuelcolvin/pydantic/issues/3745
259276
if isinstance(self.item_schema, (ModelSchema, UnionSchema)):
@@ -282,8 +299,8 @@ def is_float_enum(self) -> bool:
282299
def get_type_string(self) -> str:
283300
return f"Literal[{', '.join(repr(value) for value in self.values)}]"
284301

285-
def get_imports(self) -> Set[str]:
286-
imports = super().get_imports()
302+
def get_model_imports(self) -> Set[str]:
303+
imports = super().get_model_imports()
287304
imports.add("from typing import Literal")
288305
return imports
289306

@@ -299,13 +316,19 @@ def get_type_string(self) -> str:
299316
def get_param_type_string(self) -> str:
300317
return f"{self.class_name}Type"
301318

302-
def get_imports(self) -> Set[str]:
303-
imports = super().get_imports()
319+
def get_model_imports(self) -> Set[str]:
320+
imports = super().get_model_imports()
304321
imports.add("from pydantic import BaseModel")
305322
if self.allow_extra:
306323
imports.add("from pydantic import Extra")
307324
for prop in self.properties:
308-
imports.update(prop.get_imports())
325+
imports.update(prop.get_model_imports())
326+
return imports
327+
328+
def get_type_imports(self) -> Set[str]:
329+
imports = {"from typing_extensions import TypedDict"}
330+
for prop in self.properties:
331+
imports.update(prop.get_type_imports())
309332
return imports
310333

311334
def get_param_imports(self) -> Set[str]:
@@ -331,11 +354,17 @@ def get_param_type_string(self) -> str:
331354
return self.schemas[0].get_param_type_string()
332355
return f"Union[{', '.join(schema.get_param_type_string() for schema in self.schemas)}]"
333356

334-
def get_imports(self) -> Set[str]:
335-
imports = super().get_imports()
357+
def get_model_imports(self) -> Set[str]:
358+
imports = super().get_model_imports()
336359
imports.add("from typing import Union")
337360
for schema in self.schemas:
338-
imports.update(schema.get_imports())
361+
imports.update(schema.get_model_imports())
362+
return imports
363+
364+
def get_type_imports(self) -> Set[str]:
365+
imports = {"from typing import Union"}
366+
for schema in self.schemas:
367+
imports.update(schema.get_type_imports())
339368
return imports
340369

341370
def get_param_imports(self) -> Set[str]:
@@ -350,11 +379,11 @@ def get_using_imports(self) -> Set[str]:
350379
imports.update(schema.get_using_imports())
351380
return imports
352381

353-
def get_default_args(self) -> Dict[str, str]:
382+
def _get_default_args(self) -> Dict[str, str]:
354383
args = {}
355384
for schema in self.schemas:
356-
args.update(schema.get_default_args())
357-
args.update(super().get_default_args())
385+
args.update(schema._get_default_args())
386+
args.update(super()._get_default_args())
358387
if self.discriminator:
359388
args["discriminator"] = self.discriminator
360389
return args

codegen/templates/client/param.py.jinja renamed to codegen/templates/client/_param.py.jinja

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
{% macro path_params(endpoint) %}
32
{% for path_param in endpoint.path_params %}
43
{{ path_param.get_param_defination() }},

codegen/templates/client/request.py.jinja renamed to codegen/templates/client/_request.py.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% from "client/response.py.jinja" import build_response_model, build_error_models %}
1+
{% from "client/_response.py.jinja" import build_response_model, build_error_models %}
22

33
{% macro build_path(endpoint) %}
44
{% if endpoint.path_params %}

codegen/templates/client/client.py.jinja

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ This file is auto generated by github rest api discription.
44
See https://github.com/github/rest-api-description for more information.
55
"""
66

7-
{% from "client/param.py.jinja" import endpoint_params %}
8-
{% from "client/response.py.jinja" import build_response_type %}
9-
{% from "client/request.py.jinja" import build_request, build_request_params %}
7+
{% from "client/_param.py.jinja" import endpoint_params %}
8+
{% from "client/_response.py.jinja" import build_response_type %}
9+
{% from "client/_request.py.jinja" import build_request, build_request_params %}
1010

1111
{% for endpoint in endpoints %}
1212
{% for import_ in endpoint.get_imports() %}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{% macro build_model_docstring(model) %}
2+
"""{{ model.title or model.class_name | wordwrap(80) }}
3+
{% if model.description %}
4+
5+
{{ model.description | wordwrap(80) }}
6+
{% endif %}
7+
{% if model.examples %}
8+
9+
Examples:
10+
{% for example in model.examples %}
11+
{{ example | string | wordwrap(82) }}
12+
{% endfor %}
13+
{% endif %}
14+
"""
15+
{% endmacro %}

0 commit comments

Comments
 (0)