11import inspect
2- from copy import copy
32from enum import Enum
4- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
3+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
54
65from pydantic import BaseConfig
76from pydantic .fields import FieldInfo
8- from typing_extensions import Annotated , get_args , get_origin
7+ from typing_extensions import Annotated , Literal , get_args , get_origin
98
109from aws_lambda_powertools .event_handler .openapi .compat import (
1110 ModelField ,
1211 Required ,
1312 Undefined ,
13+ UndefinedType ,
14+ copy_field_info ,
1415 get_annotation_from_field_info ,
1516)
1617from aws_lambda_powertools .event_handler .openapi .types import PYDANTIC_V2 , CacheKey
@@ -302,7 +303,8 @@ def analyze_param(
302303 annotation : Any ,
303304 value : Any ,
304305 is_path_param : bool ,
305- ) -> Tuple [Any , Optional [ModelField ]]:
306+ is_response_param : bool ,
307+ ) -> Optional [ModelField ]:
306308 """
307309 Analyze a parameter annotation and value to determine the type and default value of the parameter.
308310
@@ -316,10 +318,12 @@ def analyze_param(
316318 The value of the parameter
317319 is_path_param
318320 Whether the parameter is a path parameter
321+ is_response_param
322+ Whether the parameter is the return annotation
319323
320324 Returns
321325 -------
322- Tuple[Any, Optional[ModelField] ]
326+ Optional[ModelField]
323327 The type annotation and the Pydantic field representing the parameter
324328 """
325329 field_info , type_annotation = _get_field_info_and_type_annotation (annotation , value , is_path_param )
@@ -336,12 +340,16 @@ def analyze_param(
336340
337341 # Check if the parameter is part of the path. Otherwise, defaults to query.
338342 if is_path_param :
339- field_info = Path (annotation = type_annotation , default = default_value )
343+ field_info = Path (annotation = type_annotation )
340344 else :
341345 field_info = Query (annotation = type_annotation , default = default_value )
342346
347+ # When we have a response field, we need to set the default value to Required
348+ if is_response_param :
349+ field_info .default = Required
350+
343351 field = _create_model_field (field_info , type_annotation , param_name , is_path_param )
344- return type_annotation , field
352+ return field
345353
346354
347355def _get_field_info_and_type_annotation (annotation , value , is_path_param : bool ) -> Tuple [Optional [FieldInfo ], Any ]:
@@ -372,7 +380,10 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
372380
373381 if isinstance (powertools_annotation , FieldInfo ):
374382 # Copy `field_info` because we mutate `field_info.default` later
375- field_info = copy (powertools_annotation )
383+ field_info = copy_field_info (
384+ field_info = powertools_annotation ,
385+ annotation = annotation ,
386+ )
376387 if field_info .default not in [Undefined , Required ]:
377388 raise AssertionError ("FieldInfo needs to have a default value of Undefined or Required" )
378389
@@ -386,6 +397,44 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
386397 return field_info , type_annotation
387398
388399
400+ def _create_response_field (
401+ name : str ,
402+ type_ : Type [Any ],
403+ default : Optional [Any ] = Undefined ,
404+ required : Union [bool , UndefinedType ] = Undefined ,
405+ model_config : Type [BaseConfig ] = BaseConfig ,
406+ field_info : Optional [FieldInfo ] = None ,
407+ alias : Optional [str ] = None ,
408+ mode : Literal ["validation" , "serialization" ] = "validation" ,
409+ ) -> ModelField :
410+ """
411+ Create a new response field. Raises if type_ is invalid.
412+ """
413+ if PYDANTIC_V2 :
414+ field_info = field_info or FieldInfo (
415+ annotation = type_ ,
416+ default = default ,
417+ alias = alias ,
418+ )
419+ else :
420+ field_info = field_info or FieldInfo ()
421+ kwargs = {"name" : name , "field_info" : field_info }
422+ if PYDANTIC_V2 :
423+ kwargs .update ({"mode" : mode })
424+ else :
425+ kwargs .update (
426+ {
427+ "type_" : type_ ,
428+ "class_validators" : {},
429+ "default" : default ,
430+ "required" : required ,
431+ "model_config" : model_config ,
432+ "alias" : alias ,
433+ },
434+ )
435+ return ModelField (** kwargs ) # type: ignore[arg-type]
436+
437+
389438def _create_model_field (
390439 field_info : Optional [FieldInfo ],
391440 type_annotation : Any ,
@@ -411,21 +460,11 @@ def _create_model_field(
411460 alias = field_info .alias or param_name
412461 field_info .alias = alias
413462
414- # Create the Pydantic field
415- kwargs = {"name" : param_name , "field_info" : field_info }
416-
417- if PYDANTIC_V2 :
418- kwargs .update ({"mode" : "validation" })
419- else :
420- kwargs .update (
421- {
422- "type_" : use_annotation ,
423- "class_validators" : {},
424- "default" : field_info .default ,
425- "required" : field_info .default in (Required , Undefined ),
426- "model_config" : BaseConfig ,
427- "alias" : alias ,
428- },
429- )
430-
431- return ModelField (** kwargs ) # type: ignore[arg-type]
463+ return _create_response_field (
464+ name = param_name ,
465+ type_ = use_annotation ,
466+ default = field_info .default ,
467+ alias = alias ,
468+ required = field_info .default in (Required , Undefined ),
469+ field_info = field_info ,
470+ )
0 commit comments