@@ -173,6 +173,9 @@ def leave(self, node, key, parent, path, ancestors):
173173 # Provide special return values as attributes
174174 BREAK , SKIP , REMOVE , IDLE = BREAK , SKIP , REMOVE , IDLE
175175
176+ def __init__ (self ):
177+ self ._visit_fns = {}
178+
176179 def __init_subclass__ (cls ) -> None :
177180 """Verify that all defined handlers are valid."""
178181 super ().__init_subclass__ ()
@@ -197,10 +200,18 @@ def __init_subclass__(cls) -> None:
197200
198201 def get_visit_fn (self , kind : str , is_leaving : bool = False ) -> Callable :
199202 """Get the visit function for the given node kind and direction."""
203+
204+ key = (kind , is_leaving )
205+ if key in self ._visit_fns :
206+ return self ._visit_fns [key ]
207+
200208 method = "leave" if is_leaving else "enter"
201209 visit_fn = getattr (self , f"{ method } _{ kind } " , None )
202210 if not visit_fn :
203211 visit_fn = getattr (self , method , None )
212+
213+ self ._visit_fns [key ] = visit_fn
214+
204215 return visit_fn
205216
206217
@@ -367,14 +378,22 @@ class ParallelVisitor(Visitor):
367378
368379 def __init__ (self , visitors : Collection [Visitor ]):
369380 """Create a new visitor from the given list of parallel visitors."""
381+ super ().__init__ ()
370382 self .visitors = visitors
371383 self .skipping : List [Any ] = [None ] * len (visitors )
384+ self ._enter_visit_fns = {}
385+ self ._leave_visit_fns = {}
372386
373387 def enter (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
388+ visit_fns = self ._enter_visit_fns .get (node .kind )
389+ if visit_fns is None :
390+ visit_fns = [v .get_visit_fn (node .kind ) for v in self .visitors ]
391+ self ._enter_visit_fns [node .kind ] = visit_fns
392+
374393 skipping = self .skipping
375394 for i , visitor in enumerate (self .visitors ):
376395 if not skipping [i ]:
377- fn = visitor . get_visit_fn ( node . kind )
396+ fn = visit_fns [ i ]
378397 if fn :
379398 result = fn (node , * args )
380399 if result is SKIP or result is False :
@@ -386,10 +405,15 @@ def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
386405 return None
387406
388407 def leave (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
408+ visit_fns = self ._leave_visit_fns .get (node .kind )
409+ if visit_fns is None :
410+ visit_fns = [v .get_visit_fn (node .kind , is_leaving = True ) for v in self .visitors ]
411+ self ._leave_visit_fns [node .kind ] = visit_fns
412+
389413 skipping = self .skipping
390414 for i , visitor in enumerate (self .visitors ):
391415 if not skipping [i ]:
392- fn = visitor . get_visit_fn ( node . kind , is_leaving = True )
416+ fn = visit_fns [ i ]
393417 if fn :
394418 result = fn (node , * args )
395419 if result is BREAK or result is True :
0 commit comments