@@ -309,10 +309,15 @@ def find(obj: ast.AST, name: str) -> ast.AST:
309
309
"""Find a particular named node in a tree"""
310
310
return next (node for node in ast .walk (obj ) if getattr (node , "name" , "" ) == name )
311
311
312
+
312
313
if typing .TYPE_CHECKING :
313
- AstVal : "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None"
314
+ AstVal : "typing.TypeAlias" = (
315
+ "int | float | complex | str | list[AstVal] | bytes | None"
316
+ )
314
317
AstValNoBytes : "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]"
315
- JSONObject : "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None"
318
+ JSONObject : "typing.TypeAlias" = (
319
+ "int | float | str | list[JSONObject] | JSONDict | None"
320
+ )
316
321
JSONDict : "typing.TypeAlias" = "dict[str, JSONObject]"
317
322
318
323
@@ -327,6 +332,7 @@ def to_serializable(val: "AstVal") -> "JSONObject":
327
332
else :
328
333
return val
329
334
335
+
330
336
def get_value (node : ast .AST ) -> "AstVal" :
331
337
"""Return the value of constant or list of constants"""
332
338
if isinstance (node , ast .Constant ):
@@ -339,7 +345,7 @@ def get_value(node: ast.AST) -> "AstVal":
339
345
if isinstance (node , (ast .List , ast .Tuple )):
340
346
return [get_value (e ) for e in node .elts ]
341
347
if isinstance (node , ast .UnaryOp ) and isinstance (node .op , ast .USub ):
342
- return - typing .cast (typing .Union [int , float , complex ], get_value (node .operand ))
348
+ return - typing .cast (typing .Union [int , float , complex ], get_value (node .operand ))
343
349
raise ValueError ("Unexpected node type" , type (node ))
344
350
345
351
@@ -372,6 +378,7 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisT
372
378
defaults = [...] * (len (args ) - len (predict .args .defaults )) + predict .args .defaults
373
379
return list (zip (args , defaults ))
374
380
381
+
375
382
def parse_assignment (assignment : ast .AST ) -> "None | tuple[str, JSONObject]" :
376
383
"""Parse an assignment into an OpenAPI object property"""
377
384
if isinstance (assignment , ast .AnnAssign ):
@@ -403,7 +410,9 @@ def parse_class(classdef: ast.AST) -> "JSONDict":
403
410
"""Parse a class definition into an OpenAPI object"""
404
411
assert isinstance (classdef , ast .ClassDef )
405
412
properties = {
406
- assignment [0 ]: assignment [1 ] for assignment in map (parse_assignment , classdef .body ) if assignment
413
+ assignment [0 ]: assignment [1 ]
414
+ for assignment in map (parse_assignment , classdef .body )
415
+ if assignment
407
416
}
408
417
return {
409
418
"title" : classdef .name ,
@@ -428,15 +437,17 @@ def resolve_name(node: ast.expr) -> str:
428
437
return node .id
429
438
if isinstance (node , ast .Index ):
430
439
# deprecated, but needed for py3.8
431
- return resolve_name (node .value ) # type: ignore
440
+ return resolve_name (node .value ) # type: ignore
432
441
if isinstance (node , ast .Attribute ):
433
442
return node .attr
434
443
if isinstance (node , ast .Subscript ):
435
444
return resolve_name (node .value )
436
445
raise ValueError ("Unexpected node type" , type (node ), ast .unparse (node ))
437
446
438
447
439
- def parse_return_annotation (tree : ast .AST , fn : str = "predict" ) -> "tuple[JSONDict, JSONDict]" :
448
+ def parse_return_annotation (
449
+ tree : ast .AST , fn : str = "predict"
450
+ ) -> "tuple[JSONDict, JSONDict]" :
440
451
predict = find (tree , fn )
441
452
if not isinstance (predict , (ast .FunctionDef , ast .AsyncFunctionDef )):
442
453
raise ValueError ("Could not find predict function" )
@@ -550,7 +561,7 @@ def extract_info(code: str) -> "JSONDict":
550
561
** return_schema ,
551
562
}
552
563
# trust me, typechecker, I know BASE_SCHEMA
553
- x : "JSONDict" = schema ["components" ]["schemas" ] # type: ignore
564
+ x : "JSONDict" = schema ["components" ]["schemas" ] # type: ignore
554
565
x .update (components )
555
566
return schema
556
567
0 commit comments