@@ -242,7 +242,8 @@ def resolve_imports(self):
242242 if len (to_nodes ) > 0
243243 }
244244
245- def filter_data (self , function : Union [None , str ] = None , namespace : Union [None , str ] = None , max_iter : int = 1000 ):
245+ def filter_data (self , function : Union [None , str ] = None , namespace : Union [None , str ] = None , max_iter : int = 1000 ,
246+ filter_down = True , filter_up = False ):
246247 if function :
247248 function_name = function .split ("." )[- 1 ]
248249 function_namespace = "." .join (function .split ("." )[:- 1 ])
@@ -251,9 +252,10 @@ def filter_data(self, function: Union[None, str] = None, namespace: Union[None,
251252 else :
252253 node = None
253254
254- self .filter (node = node , namespace = namespace )
255+ self .filter (node = node , namespace = namespace , filter_down = filter_down , filter_up = filter_up )
255256
256- def filter (self , node : Union [None , Node ] = None , namespace : Union [str , None ] = None , max_iter : int = 1000 ):
257+ def filter (self , node : Union [None , Node ] = None , namespace : Union [str , None ] = None , max_iter : int = 1000 ,
258+ filter_down : bool = True , filter_up : bool = False ):
257259 """
258260 filter callgraph nodes that related to `node` or are in `namespace`
259261
@@ -262,12 +264,15 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N
262264 namespace: namespace to search in (name of top level module),
263265 if None, determines namespace from `node`
264266 max_iter: maximum number of iterations and nodes to iterate
267+ filter_down: filter nodes in downward
268+ filter_up: filter nodes in upward
265269
266270 Returns:
267271 self
268272 """
269273 # filter the nodes to avoid cluttering the callgraph with irrelevant information
270- filtered_nodes = self .get_related_nodes (node , namespace = namespace , max_iter = max_iter )
274+ filtered_nodes = self .get_related_nodes (node , namespace = namespace , max_iter = max_iter ,
275+ find_downward = filter_down , find_upward = filter_up )
271276
272277 self .nodes = {name : [node for node in nodes if node in filtered_nodes ] for name , nodes in self .nodes .items ()}
273278 self .uses_edges = {
@@ -283,7 +288,8 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N
283288 return self
284289
285290 def get_related_nodes (
286- self , node : Union [None , Node ] = None , namespace : Union [str , None ] = None , max_iter : int = 1000
291+ self , node : Union [None , Node ] = None , namespace : Union [str , None ] = None , max_iter : int = 1000 ,
292+ find_downward : bool = True , find_upward : bool = False
287293 ) -> set :
288294 """
289295 get nodes that related to `node` or are in `namespace`
@@ -293,6 +299,8 @@ def get_related_nodes(
293299 namespace: namespace to search in (name of top level module),
294300 if None, determines namespace from `node`
295301 max_iter: maximum number of iterations and nodes to iterate
302+ find_downward: look for nodes in downward
303+ find_upward: look for nodes in upward
296304
297305 Returns:
298306 set: set of nodes related to `node` including `node` itself
@@ -316,63 +324,67 @@ def get_related_nodes(
316324 namespace = node .namespace .strip ("." ).split ("." , 1 )[0 ]
317325 queue = [node ]
318326
319- # use queue system to search through nodes
320- # essentially add a node to the queue and then search all connected nodes which are in turn added to the queue
321- # until the queue itself is empty or the maximum limit of max_iter searches have been hit
322- downstream_new_nodes = new_nodes .copy ()
323- downstream_queue = queue .copy ()
324- i = max_iter
325- while len (downstream_queue ) > 0 :
326- item = downstream_queue .pop ()
327- if item not in downstream_new_nodes :
328- downstream_new_nodes .add (item )
329- i -= 1
330- if i < 0 :
331- break
332- # add used nodes that are not already added and are in desired namespace
333- downstream_queue .extend (
334- [
335- n
336- for n in self .uses_edges .get (item , [])
337- if n in self .uses_edges and n not in downstream_new_nodes and namespace in n .namespace
338- ]
339- )
340- # add defined nodes that are not already added and are in desired namespace
341- downstream_queue .extend (
342- [
343- n
344- for n in self .defines_edges .get (item , [])
345- if n in self .defines_edges and n not in downstream_new_nodes and namespace in n .namespace
346- ]
347- )
327+ downstream_new_nodes = set ()
328+ if find_downward :
329+ # use queue system to search through nodes
330+ # essentially add a node to the queue and then search all connected nodes which are in turn added to the queue
331+ # until the queue itself is empty or the maximum limit of max_iter searches have been hit
332+ downstream_new_nodes = new_nodes .copy ()
333+ downstream_queue = queue .copy ()
334+ i = max_iter
335+ while len (downstream_queue ) > 0 :
336+ item = downstream_queue .pop ()
337+ if item not in downstream_new_nodes :
338+ downstream_new_nodes .add (item )
339+ i -= 1
340+ if i < 0 :
341+ break
342+ # add used nodes that are not already added and are in desired namespace
343+ downstream_queue .extend (
344+ [
345+ n
346+ for n in self .uses_edges .get (item , [])
347+ if n in self .uses_edges and n not in downstream_new_nodes and namespace in n .namespace
348+ ]
349+ )
350+ # add defined nodes that are not already added and are in desired namespace
351+ downstream_queue .extend (
352+ [
353+ n
354+ for n in self .defines_edges .get (item , [])
355+ if n in self .defines_edges and n not in downstream_new_nodes and namespace in n .namespace
356+ ]
357+ )
348358
349- # get callers of node
350- upstream_new_nodes = new_nodes .copy ()
351- upstream_queue = queue .copy ()
352- i = max_iter
353- while len (upstream_queue ) > 0 :
354- item = upstream_queue .pop ()
355- if item not in upstream_new_nodes :
356- upstream_new_nodes .add (item )
357- i -= 1
358- if i < 0 :
359- break
360- # add used nodes that are not already added and are in desired namespace
361- upstream_queue .extend (
362- [
363- n
364- for n in self .get_callers (self .uses_edges , item )
365- if n in self .uses_edges and n not in upstream_new_nodes and namespace in n .namespace
366- ]
367- )
368- # add defined nodes that are not already added and are in desired namespace
369- upstream_queue .extend (
370- [
371- n
372- for n in self .get_callers (self .defines_edges , item )
373- if n in self .defines_edges and n not in upstream_new_nodes and namespace in n .namespace
374- ]
375- )
359+ upstream_new_nodes = set ()
360+ if find_upward :
361+ # get callers of node
362+ upstream_new_nodes = new_nodes .copy ()
363+ upstream_queue = queue .copy ()
364+ i = max_iter
365+ while len (upstream_queue ) > 0 :
366+ item = upstream_queue .pop ()
367+ if item not in upstream_new_nodes :
368+ upstream_new_nodes .add (item )
369+ i -= 1
370+ if i < 0 :
371+ break
372+ # add used nodes that are not already added and are in desired namespace
373+ upstream_queue .extend (
374+ [
375+ n
376+ for n in self .get_callers (self .uses_edges , item )
377+ if n in self .uses_edges and n not in upstream_new_nodes and namespace in n .namespace
378+ ]
379+ )
380+ # add defined nodes that are not already added and are in desired namespace
381+ upstream_queue .extend (
382+ [
383+ n
384+ for n in self .get_callers (self .defines_edges , item )
385+ if n in self .defines_edges and n not in upstream_new_nodes and namespace in n .namespace
386+ ]
387+ )
376388
377389 return downstream_new_nodes .union (upstream_new_nodes )
378390
0 commit comments