@@ -20,19 +20,23 @@ def get_param_type_string(self) -> str:
20
20
"""Get type string used by client codegen"""
21
21
return self .get_type_string ()
22
22
23
- def get_imports (self ) -> Set [str ]:
24
- """Get schema needed imports for creating the schema """
23
+ def get_model_imports (self ) -> Set [str ]:
24
+ """Get schema needed imports for model codegen """
25
25
return set ()
26
26
27
+ def get_type_imports (self ) -> Set [str ]:
28
+ """Get schema needed imports for types codegen"""
29
+ return self .get_model_imports ()
30
+
27
31
def get_param_imports (self ) -> Set [str ]:
28
- """Get schema needed imports for typing params """
29
- return self .get_imports ()
32
+ """Get schema needed imports for client param codegen """
33
+ return self .get_model_imports ()
30
34
31
35
def get_using_imports (self ) -> Set [str ]:
32
- """Get schema needed imports for using """
33
- return self .get_imports ()
36
+ """Get schema needed imports for client request codegen """
37
+ return self .get_model_imports ()
34
38
35
- def get_default_args (self ) -> Dict [str , str ]:
39
+ def _get_default_args (self ) -> Dict [str , str ]:
36
40
"""Get pydantic field info args"""
37
41
default = self .default
38
42
args = {}
@@ -59,18 +63,26 @@ def get_type_string(self) -> str:
59
63
return type_string if self .required else f"Union[Unset, { type_string } ]"
60
64
61
65
def get_param_type_string (self ) -> str :
66
+ """Get type string used by client codegen"""
62
67
type_string = self .schema_data .get_param_type_string ()
63
68
return type_string if self .required else f"Union[Unset, { type_string } ]"
64
69
65
- def get_imports (self ) -> Set [str ]:
66
- """Get schema needed imports for creating the schema """
67
- imports = self .schema_data .get_imports ()
70
+ def get_model_imports (self ) -> Set [str ]:
71
+ """Get schema needed imports for model codegen """
72
+ imports = self .schema_data .get_model_imports ()
68
73
imports .add ("from pydantic import Field" )
69
74
if not self .required :
70
75
imports .add ("from typing import Union" )
71
76
imports .add ("from githubkit.utils import UNSET, Unset" )
72
77
return imports
73
78
79
+ def get_type_imports (self ) -> Set [str ]:
80
+ """Get schema needed imports for type codegen"""
81
+ imports = self .schema_data .get_type_imports ()
82
+ if not self .required :
83
+ imports .add ("from typing_extensions import NotRequired" )
84
+ return imports
85
+
74
86
def get_param_imports (self ) -> Set [str ]:
75
87
"""Get schema needed imports for typing params"""
76
88
imports = self .schema_data .get_param_imports ()
@@ -79,8 +91,8 @@ def get_param_imports(self) -> Set[str]:
79
91
imports .add ("from githubkit.utils import UNSET, Unset" )
80
92
return imports
81
93
82
- def get_default_string (self ) -> str :
83
- args = self .schema_data .get_default_args ()
94
+ def _get_default_string (self ) -> str :
95
+ args = self .schema_data ._get_default_args ()
84
96
if "default" not in args and "default_factory" not in args :
85
97
args ["default" ] = "..." if self .required else "UNSET"
86
98
if self .prop_name != self .name :
@@ -103,12 +115,12 @@ def get_param_defination(self) -> str:
103
115
def get_model_defination (self ) -> str :
104
116
"""Get defination used by model codegen"""
105
117
type_ = self .get_type_string ()
106
- default = self .get_default_string ()
118
+ default = self ._get_default_string ()
107
119
return f"{ self .prop_name } : { type_ } = { default } "
108
120
109
121
def get_type_defination (self ) -> str :
110
122
"""Get defination used by types codegen"""
111
- type_ = self .get_param_type_string ()
123
+ type_ = self .schema_data . get_param_type_string ()
112
124
return (
113
125
f"{ self .prop_name } : { type_ if self .required else f'NotRequired[{ type_ } ]' } "
114
126
)
@@ -117,8 +129,8 @@ def get_type_defination(self) -> str:
117
129
class AnySchema (SchemaData ):
118
130
_type_string : ClassVar [str ] = "Any"
119
131
120
- def get_imports (self ) -> Set [str ]:
121
- imports = super ().get_imports ()
132
+ def get_model_imports (self ) -> Set [str ]:
133
+ imports = super ().get_model_imports ()
122
134
imports .add ("from typing import Any" )
123
135
return imports
124
136
@@ -140,8 +152,8 @@ class IntSchema(SchemaData):
140
152
141
153
_type_string : ClassVar [str ] = "int"
142
154
143
- def get_default_args (self ) -> Dict [str , str ]:
144
- args = super ().get_default_args ()
155
+ def _get_default_args (self ) -> Dict [str , str ]:
156
+ args = super ()._get_default_args ()
145
157
if self .multiple_of is not None :
146
158
args ["multiple_of" ] = repr (self .multiple_of )
147
159
if self .maximum is not None :
@@ -164,8 +176,8 @@ class FloatSchema(SchemaData):
164
176
165
177
_type_string : ClassVar [str ] = "float"
166
178
167
- def get_default_args (self ) -> Dict [str , str ]:
168
- args = super ().get_default_args ()
179
+ def _get_default_args (self ) -> Dict [str , str ]:
180
+ args = super ()._get_default_args ()
169
181
if self .multiple_of is not None :
170
182
args ["multiple_of" ] = str (self .multiple_of )
171
183
if self .maximum is not None :
@@ -186,8 +198,8 @@ class StringSchema(SchemaData):
186
198
187
199
_type_string : ClassVar [str ] = "str"
188
200
189
- def get_default_args (self ) -> Dict [str , str ]:
190
- args = super ().get_default_args ()
201
+ def _get_default_args (self ) -> Dict [str , str ]:
202
+ args = super ()._get_default_args ()
191
203
if self .min_length is not None :
192
204
args ["min_length" ] = str (self .min_length )
193
205
if self .max_length is not None :
@@ -200,26 +212,26 @@ def get_default_args(self) -> Dict[str, str]:
200
212
class DateTimeSchema (SchemaData ):
201
213
_type_string : ClassVar [str ] = "datetime"
202
214
203
- def get_imports (self ) -> Set [str ]:
204
- imports = super ().get_imports ()
215
+ def get_model_imports (self ) -> Set [str ]:
216
+ imports = super ().get_model_imports ()
205
217
imports .add ("from datetime import datetime" )
206
218
return imports
207
219
208
220
209
221
class DateSchema (SchemaData ):
210
222
_type_string : ClassVar [str ] = "date"
211
223
212
- def get_imports (self ) -> Set [str ]:
213
- imports = super ().get_imports ()
224
+ def get_model_imports (self ) -> Set [str ]:
225
+ imports = super ().get_model_imports ()
214
226
imports .add ("from datetime import date" )
215
227
return imports
216
228
217
229
218
230
class FileSchema (SchemaData ):
219
231
_type_string : ClassVar [str ] = "FileTypes"
220
232
221
- def get_imports (self ) -> Set [str ]:
222
- imports = super ().get_imports ()
233
+ def get_model_imports (self ) -> Set [str ]:
234
+ imports = super ().get_model_imports ()
223
235
imports .add ("from githubkit.typing import FileTypes" )
224
236
return imports
225
237
@@ -236,10 +248,15 @@ def get_type_string(self) -> str:
236
248
def get_param_type_string (self ) -> str :
237
249
return f"List[{ self .item_schema .get_param_type_string ()} ]"
238
250
239
- def get_imports (self ) -> Set [str ]:
240
- imports = super ().get_imports ()
251
+ def get_model_imports (self ) -> Set [str ]:
252
+ imports = super ().get_model_imports ()
241
253
imports .add ("from typing import List" )
242
- imports .update (self .item_schema .get_imports ())
254
+ imports .update (self .item_schema .get_model_imports ())
255
+ return imports
256
+
257
+ def get_type_imports (self ) -> Set [str ]:
258
+ imports = {"from typing import List" }
259
+ imports .update (self .item_schema .get_type_imports ())
243
260
return imports
244
261
245
262
def get_param_imports (self ) -> Set [str ]:
@@ -252,8 +269,8 @@ def get_using_imports(self) -> Set[str]:
252
269
imports .update (self .item_schema .get_using_imports ())
253
270
return imports
254
271
255
- def get_default_args (self ) -> Dict [str , str ]:
256
- args = super ().get_default_args ()
272
+ def _get_default_args (self ) -> Dict [str , str ]:
273
+ args = super ()._get_default_args ()
257
274
# FIXME: remove list constraints due to forwardref not supported
258
275
# See https://github.com/samuelcolvin/pydantic/issues/3745
259
276
if isinstance (self .item_schema , (ModelSchema , UnionSchema )):
@@ -282,8 +299,8 @@ def is_float_enum(self) -> bool:
282
299
def get_type_string (self ) -> str :
283
300
return f"Literal[{ ', ' .join (repr (value ) for value in self .values )} ]"
284
301
285
- def get_imports (self ) -> Set [str ]:
286
- imports = super ().get_imports ()
302
+ def get_model_imports (self ) -> Set [str ]:
303
+ imports = super ().get_model_imports ()
287
304
imports .add ("from typing import Literal" )
288
305
return imports
289
306
@@ -299,13 +316,19 @@ def get_type_string(self) -> str:
299
316
def get_param_type_string (self ) -> str :
300
317
return f"{ self .class_name } Type"
301
318
302
- def get_imports (self ) -> Set [str ]:
303
- imports = super ().get_imports ()
319
+ def get_model_imports (self ) -> Set [str ]:
320
+ imports = super ().get_model_imports ()
304
321
imports .add ("from pydantic import BaseModel" )
305
322
if self .allow_extra :
306
323
imports .add ("from pydantic import Extra" )
307
324
for prop in self .properties :
308
- imports .update (prop .get_imports ())
325
+ imports .update (prop .get_model_imports ())
326
+ return imports
327
+
328
+ def get_type_imports (self ) -> Set [str ]:
329
+ imports = {"from typing_extensions import TypedDict" }
330
+ for prop in self .properties :
331
+ imports .update (prop .get_type_imports ())
309
332
return imports
310
333
311
334
def get_param_imports (self ) -> Set [str ]:
@@ -331,11 +354,17 @@ def get_param_type_string(self) -> str:
331
354
return self .schemas [0 ].get_param_type_string ()
332
355
return f"Union[{ ', ' .join (schema .get_param_type_string () for schema in self .schemas )} ]"
333
356
334
- def get_imports (self ) -> Set [str ]:
335
- imports = super ().get_imports ()
357
+ def get_model_imports (self ) -> Set [str ]:
358
+ imports = super ().get_model_imports ()
336
359
imports .add ("from typing import Union" )
337
360
for schema in self .schemas :
338
- imports .update (schema .get_imports ())
361
+ imports .update (schema .get_model_imports ())
362
+ return imports
363
+
364
+ def get_type_imports (self ) -> Set [str ]:
365
+ imports = {"from typing import Union" }
366
+ for schema in self .schemas :
367
+ imports .update (schema .get_type_imports ())
339
368
return imports
340
369
341
370
def get_param_imports (self ) -> Set [str ]:
@@ -350,11 +379,11 @@ def get_using_imports(self) -> Set[str]:
350
379
imports .update (schema .get_using_imports ())
351
380
return imports
352
381
353
- def get_default_args (self ) -> Dict [str , str ]:
382
+ def _get_default_args (self ) -> Dict [str , str ]:
354
383
args = {}
355
384
for schema in self .schemas :
356
- args .update (schema .get_default_args ())
357
- args .update (super ().get_default_args ())
385
+ args .update (schema ._get_default_args ())
386
+ args .update (super ()._get_default_args ())
358
387
if self .discriminator :
359
388
args ["discriminator" ] = self .discriminator
360
389
return args
0 commit comments