11import  logging 
22from  itertools  import  groupby 
3- from  typing  import  Any , Callable , List , Optional , Type , TypeVar ,  Union 
3+ from  typing  import  Any , Callable , List , Optional , Type , Union 
44
55from  aws_lambda_powertools .utilities .data_classes  import  AppSyncResolverEvent 
66from  aws_lambda_powertools .utilities .typing  import  LambdaContext 
77
88logger  =  logging .getLogger (__name__ )
99
10- AppSyncResolverEventT  =  TypeVar ("AppSyncResolverEventT" , bound = AppSyncResolverEvent )
1110
11+ class  RouterContext :
12+     def  __init__ (self ):
13+         super ().__init__ ()
14+         self .context  =  {}
1215
13- class  BaseRouter :
14-     current_event : Union [AppSyncResolverEventT , List [AppSyncResolverEventT ]]  # type: ignore[valid-type] 
15-     lambda_context : LambdaContext 
16-     context : dict 
16+     def  append_context (self , ** additional_context ):
17+         """Append key=value data as routing context""" 
18+         self .context .update (** additional_context )
1719
20+     def  clear_context (self ):
21+         """Resets routing context""" 
22+         self .context .clear ()
23+ 
24+ 
25+ class  ResolverRegistry :
1826    def  __init__ (self ):
27+         super ().__init__ ()
1928        self ._resolvers : dict  =  {}
29+         self ._batch_resolvers : dict  =  {}
2030
2131    def  resolver (self , type_name : str  =  "*" , field_name : Optional [str ] =  None ):
2232        """Registers the resolver for field_name 
@@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
2939            Field name 
3040        """ 
3141
32-         def  register_resolver (func ):
42+         def  register (func ):
3343            logger .debug (f"Adding resolver `{ func .__name__ } { type_name } { field_name }  )
3444            self ._resolvers [f"{ type_name } { field_name }  ] =  {"func" : func }
3545            return  func 
3646
37-         return  register_resolver 
47+         return  register 
3848
39-     def  append_context (self , ** additional_context ):
40-         """Append key=value data as routing context""" 
41-         self .context .update (** additional_context )
49+     def  batch_resolver (self , type_name : str  =  "*" , field_name : Optional [str ] =  None ):
50+         """Registers the resolver for field_name 
4251
43-     def  clear_context (self ):
44-         """Resets routing context""" 
45-         self .context .clear ()
52+         Parameters 
53+         ---------- 
54+         type_name : str 
55+             Type name 
56+         field_name : str 
57+             Field name 
58+         """ 
4659
60+         def  register (func ):
61+             logger .debug (f"Adding batch resolver `{ func .__name__ } { type_name } { field_name }  )
62+             self ._batch_resolvers [f"{ type_name } { field_name }  ] =  {"func" : func }
63+             return  func 
4764
48- class  AppSyncResolver (BaseRouter ):
65+         return  register 
66+ 
67+ 
68+ class  AppSyncResolver (ResolverRegistry , RouterContext ):
4969    """ 
5070    AppSync resolver decorator 
5171
@@ -78,16 +98,20 @@ def common_field() -> str:
7898
7999    def  __init__ (self ):
80100        super ().__init__ ()
81-         self .context  =  {}  # early init as customers might add context before event resolution 
101+         self .current_batch_event : List [AppSyncResolverEvent ] =  []
102+         self .current_event : Optional [AppSyncResolverEvent ] =  None 
82103
83104    def  resolve (
84-         self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] =  AppSyncResolverEvent 
105+         self ,
106+         event : Union [dict , List [dict ]],
107+         context : LambdaContext ,
108+         data_model : Type [AppSyncResolverEvent ] =  AppSyncResolverEvent ,
85109    ) ->  Any :
86110        """Resolve field_name 
87111
88112        Parameters 
89113        ---------- 
90-         event : dict 
114+         event : dict | List[dict]  
91115            Lambda event 
92116        context : LambdaContext 
93117            Lambda context 
@@ -152,33 +176,38 @@ def lambda_handler(event, context):
152176        ValueError 
153177            If we could not find a field resolver 
154178        """ 
155-         # Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage 
156- 
157-         BaseRouter .lambda_context  =  context 
158- 
159-         # If event is a list it means that AppSync sent batch request 
160-         if  isinstance (event , list ):
161-             event_groups  =  [
162-                 {"field_name" : field_name , "events" : list (events )}
163-                 for  field_name , events  in  groupby (event , key = lambda  x : x ["info" ]["fieldName" ])
164-             ]
165-             if  len (event_groups ) >  1 :
166-                 ValueError ("batch with different field names. It shouldn't happen!" )
167- 
168-             appconfig_events  =  [data_model (event ) for  event  in  event_groups [0 ]["events" ]]
169-             BaseRouter .current_event  =  appconfig_events 
170-             resolver  =  self ._get_resolver (appconfig_events [0 ].type_name , event_groups [0 ]["field_name" ])
171-             response  =  resolver ()
172-         else :
173-             appconfig_event  =  data_model (event )
174-             BaseRouter .current_event  =  appconfig_event 
175-             resolver  =  self ._get_resolver (appconfig_event .type_name , appconfig_event .field_name )
176-             response  =  resolver (** appconfig_event .arguments )
177179
180+         self .lambda_context  =  context 
181+ 
182+         response  =  (
183+             self ._call_batch_resolver (event , data_model )
184+             if  isinstance (event , list )
185+             else  self ._call_resolver (event , data_model )
186+         )
178187        self .clear_context ()
179188
180189        return  response 
181190
191+     def  _call_resolver (self , event : dict , data_model : Type [AppSyncResolverEvent ]) ->  Any :
192+         self .current_event  =  data_model (event )
193+         resolver  =  self ._get_resolver (self .current_event .type_name , self .current_event .field_name )
194+         return  resolver (** self .current_event .arguments )
195+ 
196+     def  _call_batch_resolver (self , event : List [dict ], data_model : Type [AppSyncResolverEvent ]) ->  list [Any ]:
197+         event_groups  =  [
198+             {"field_name" : field_name , "events" : list (events )}
199+             for  field_name , events  in  groupby (event , key = lambda  x : x ["info" ]["fieldName" ])
200+         ]
201+         if  len (event_groups ) >  1 :
202+             ValueError ("batch with different field names. It shouldn't happen!" )
203+ 
204+         self .current_batch_event  =  [data_model (event ) for  event  in  event_groups [0 ]["events" ]]
205+         resolver  =  self ._get_batch_resolver (
206+             self .current_batch_event [0 ].type_name , self .current_batch_event [0 ].field_name 
207+         )
208+ 
209+         return  [resolver (event = appconfig_event ) for  appconfig_event  in  self .current_batch_event ]
210+ 
182211    def  _get_resolver (self , type_name : str , field_name : str ) ->  Callable :
183212        """Get resolver for field_name 
184213
@@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
200229            raise  ValueError (f"No resolver found for '{ full_name }  )
201230        return  resolver ["func" ]
202231
232+     def  _get_batch_resolver (self , type_name : str , field_name : str ) ->  Callable :
233+         """Get resolver for field_name 
234+ 
235+         Parameters 
236+         ---------- 
237+         type_name : str 
238+             Type name 
239+         field_name : str 
240+             Field name 
241+ 
242+         Returns 
243+         ------- 
244+         Callable 
245+             callable function and configuration 
246+         """ 
247+         full_name  =  f"{ type_name } { field_name }  
248+         resolver  =  self ._batch_resolvers .get (full_name , self ._batch_resolvers .get (f"*.{ field_name }  ))
249+         if  not  resolver :
250+             raise  ValueError (f"No batch resolver found for '{ full_name }  )
251+         return  resolver ["func" ]
252+ 
203253    def  __call__ (
204-         self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] =  AppSyncResolverEvent 
254+         self ,
255+         event : Union [dict , List [dict ]],
256+         context : LambdaContext ,
257+         data_model : Type [AppSyncResolverEvent ] =  AppSyncResolverEvent ,
205258    ) ->  Any :
206259        """Implicit lambda handler which internally calls `resolve`""" 
207260        return  self .resolve (event , context , data_model )
@@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None:
222275        self ._resolvers .update (router ._resolvers )
223276
224277
225- class  Router (BaseRouter ):
278+ class  Router (RouterContext ,  ResolverRegistry ):
226279    def  __init__ (self ):
227280        super ().__init__ ()
228-         self .context  =  {}  # early init as customers might add context before event resolution 
0 commit comments