Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 149 additions & 30 deletions flowrep/models/parsers/dependency_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import ast
import builtins
import types
from collections.abc import Callable

Expand All @@ -9,6 +12,13 @@
CallDependencies = dict[versions.VersionInfo, Callable]


def _get_collector(func: types.FunctionType) -> CallCollector:
tree = parser_helpers.get_ast_function_node(func)
collector = CallCollector()
collector.visit(tree)
return collector


def get_call_dependencies(
func: types.FunctionType,
version_scraping: versions.VersionScrapingMap | None = None,
Expand All @@ -23,6 +33,35 @@ def get_call_dependencies(
for every resolved callee that is a :class:`~types.FunctionType` (i.e. has
inspectable source), the function recurses into the callee's own scope.

Args:
func: The function whose call-graph to analyse.
version_scraping (VersionScrapingMap | None): Since some modules may store
their version in other ways, this provides an optional map between module
names and callables to leverage for extracting that module's version.
_call_dependencies: Accumulator for recursive calls — do not pass manually.
_visited: Fully-qualified names already traversed — do not pass manually.

Returns:
A mapping from :class:`VersionInfo` to the callables found under that
identity across the entire (sub-)tree.
"""
return get_dependencies(func, version_scraping, _call_dependencies, _visited)[0]


def get_dependencies(
func: types.FunctionType,
version_scraping: versions.VersionScrapingMap | None = None,
_call_dependencies: CallDependencies | None = None,
_visited: set[str] | None = None,
) -> tuple[CallDependencies, list[object]]:
"""
Recursively collect all callable dependencies of *func* via AST introspection.

Each dependency is keyed by its :class:`~pyiron_snippets.versions.VersionInfo`
and maps to the callables instance with that identity. The search is depth-first:
for every resolved callee that is a :class:`~types.FunctionType` (i.e. has
inspectable source), the function recurses into the callee's own scope.

Args:
func: The function whose call-graph to analyse.
version_scraping (VersionScrapingMap | None): Since some modules may store
Expand All @@ -36,47 +75,35 @@ def get_call_dependencies(
identity across the entire (sub-)tree.
"""
call_dependencies: CallDependencies = _call_dependencies or {}
variables = []
visited: set[str] = _visited or set()

func_fqn = versions.VersionInfo.of(func).fully_qualified_name
if func_fqn in visited:
return call_dependencies
return call_dependencies, variables
visited.add(func_fqn)

scope = object_scope.get_scope(func)
tree = parser_helpers.get_ast_function_node(func)
collector = CallCollector()
collector.visit(tree)
collector = _get_collector(func)
items = collector.items.difference(set(dir(builtins)))

for call in collector.calls:
for item in items:
try:
caller = object_scope.resolve_symbol_to_object(call, scope)
obj = object_scope.resolve_attribute_to_object(item, scope)
except (ValueError, TypeError):
continue

if not callable(caller): # pragma: no cover
# Under remotely normal circumstances, this should be unreachable
raise TypeError(
f"Caller {caller} is not callable, yet was generated from the list of "
f"ast.Call calls, in particular {call}. We're expecting these to "
f"actually connect to callables. Please raise a GitHub issue if you "
f"think this is not a mistake."
)
if callable(obj): # pragma: no cover
info = versions.VersionInfo.of(obj, version_scraping=version_scraping)
call_dependencies[info] = obj

info = versions.VersionInfo.of(caller, version_scraping=version_scraping)
# In principle, we open ourselves to overwriting an existing dependency here,
# but it would need to somehow have exactly the same version info (including
# qualname) yet be a different object.
# This ought not happen by accident, and in case it somehow does happen on
# purpose (it probably shouldn't), we just silently keep the more recent one.

call_dependencies[info] = caller

# Depth-first search on dependencies — only possible when we have source
if isinstance(caller, types.FunctionType):
get_call_dependencies(caller, version_scraping, call_dependencies, visited)
# Depth-first search on dependencies — only possible when we have source
if isinstance(obj, types.FunctionType):
get_call_dependencies(obj, version_scraping, call_dependencies, visited)
else:
variables.append(obj)

return call_dependencies
return call_dependencies, variables


def split_by_version_availability(
Expand Down Expand Up @@ -104,8 +131,100 @@ def split_by_version_availability(

class CallCollector(ast.NodeVisitor):
def __init__(self):
self.calls: list[ast.expr] = []
self.items: set[str] = set()
self.local_vars: set[str] = set()

def _append_item(self, node: ast.expr) -> None:
item = ast.unparse(node)
if item.split(".")[0] not in self.local_vars:
self.items.add(item.split("(")[0])

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
# Collect function arguments as local variables
for arg in node.args.args:
self.local_vars.add(arg.arg)
# Check for type hints in arguments
if arg.annotation and isinstance(arg.annotation, ast.Name):
self._append_item(arg.annotation)

# Check for type hints in the return type
if node.returns and isinstance(node.returns, ast.Name):
self._append_item(node.returns)

# Visit the body of the function
self.generic_visit(node)

# Clear local variables after leaving the function scope
self.local_vars.clear()

def visit_Assign(self, node: ast.Assign) -> None:
# Handle multiple assignments and unpacking
for target in node.targets:
self._process_assignment_target(target)
self.generic_visit(node)

def _process_assignment_target(self, target):
# Recursively process assignment targets to handle unpacking
if isinstance(target, ast.Attribute):
if target.id not in self.local_vars:
self._append_item(target)
elif isinstance(target, ast.Name):
# Add the variable name to local_vars
self.local_vars.add(target.id)
elif isinstance(target, (ast.Tuple, ast.List)):
# Handle tuple or list unpacking (e.g., x, y = ...)
for element in target.elts:
self._process_assignment_target(element)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
# Handle annotated assignments (e.g., x: CustomType = 42)
if isinstance(node.target, ast.Name):
self.local_vars.add(node.target.id)
if node.annotation and isinstance(node.annotation, ast.Name):
self._append_item(node.annotation)
self.generic_visit(node)

def visit_Name(self, node: ast.Name) -> None:
# Collect all variables that are not locally defined
if node.id not in self.local_vars:
self._append_item(node)

def visit_Attribute(self, node: ast.Attribute) -> None:
# Collect attributes that are not locally defined
self._append_item(node)

def visit_For(self, node: ast.For) -> None:
# Handle loop variables as local variables
self._process_assignment_target(node.target)
self.generic_visit(node)

def visit_With(self, node: ast.With) -> None:
# Handle variables defined in with statements (e.g., with open(...) as f)
for item in node.items:
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
self.local_vars.add(item.optional_vars.id)
self.generic_visit(node)

def visit_ListComp(self, node: ast.ListComp) -> None:
# Handle variables defined in list comprehensions
for generator in node.generators:
self._process_assignment_target(generator.target)
self.generic_visit(node)

def visit_DictComp(self, node: ast.DictComp) -> None:
# Handle variables defined in dict comprehensions
for generator in node.generators:
self._process_assignment_target(generator.target)
self.generic_visit(node)

def visit_SetComp(self, node: ast.SetComp) -> None:
# Handle variables defined in set comprehensions
for generator in node.generators:
self._process_assignment_target(generator.target)
self.generic_visit(node)

def visit_Call(self, node: ast.Call) -> None:
self.calls.append(node.func)
def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None:
# Handle variables defined in generator expressions
for generator in node.generators:
self._process_assignment_target(generator.target)
self.generic_visit(node)
22 changes: 12 additions & 10 deletions flowrep/models/parsers/object_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,30 @@ def get_scope(func: FunctionType) -> ScopeProxy:
return ScopeProxy(inspect.getmodule(func).__dict__ | vars(builtins))


def resolve_attribute_to_object(attribute: str, scope: ScopeProxy | object) -> object:
obj = None
try:
for attr in attribute.split("."):
obj = getattr(obj or scope, attr)
return obj
except AttributeError as e:
raise ValueError(f"Could not find attribute '{attr}' of {attribute}") from e


def resolve_symbol_to_object(
node: ast.expr, # Expecting a Name or Attribute here, and will otherwise TypeError
scope: ScopeProxy | object,
_chain: list[str] | None = None,
) -> object:
""" """
_chain = _chain or []
error_suffix = f" while attempting to resolve the symbol chain '{'.'.join(_chain)}'"
if isinstance(node, ast.Name):
attr = node.id
try:
obj = getattr(scope, attr)
for attr in _chain:
obj = getattr(obj, attr)
return obj
except AttributeError as e:
raise ValueError(f"Could not find attribute '{attr}' {error_suffix}") from e
return resolve_attribute_to_object(".".join([node.id] + _chain), scope)
elif isinstance(node, ast.Attribute):
return resolve_symbol_to_object(node.value, scope, [node.attr] + _chain)
else:
raise TypeError(
f"Cannot resolve symbol {node} {error_suffix}. "
f"Cannot resolve symbol {node} or the symbol chain '{'.'.join(_chain)}'. "
f"Expected an ast.Name or chain of ast.Attribute and ast.Name, but got "
f"{node}."
)
22 changes: 21 additions & 1 deletion tests/unit/models/parsers/test_dependency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def _fqns(deps: dependency_parser.CallDependencies) -> set[str]:
return {info.fully_qualified_name for info in deps}


MyCustomType = int | float


def _custom_type_used(x):
return isinstance(x, MyCustomType)


def _custom_type_type_hint(x: MyCustomType):
y = x
return y


class TestGetCallDependencies(unittest.TestCase):
"""Tests for :func:`dependency_parser.get_call_dependencies`."""

Expand Down Expand Up @@ -132,7 +144,7 @@ def test_cycle_does_not_recurse_infinitely(self):

def test_builtin_callable_included(self):
deps = dependency_parser.get_call_dependencies(_calls_len)
self.assertIn(_fqn(len), _fqns(deps))
self.assertEqual(_fqns(deps), set())

def test_returns_dict_type(self):
deps = dependency_parser.get_call_dependencies(_leaf)
Expand Down Expand Up @@ -171,6 +183,14 @@ def test_non_callable_resolved_symbol_is_skipped(self):
deps = dependency_parser.get_call_dependencies(_calls_non_callable)
self.assertIsInstance(deps, dict)

def test_variables(self):
self.assertEqual(
dependency_parser.get_dependencies(_custom_type_used)[1], [int | float]
)
self.assertEqual(
dependency_parser.get_dependencies(_custom_type_type_hint)[1], [int | float]
)


class TestSplitByVersionAvailability(unittest.TestCase):
"""Tests for :func:`dependency_parser.split_by_version_availability`."""
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/models/parsers/test_object_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ def test_unrecognized_node_raises(self):
node = ast.Constant(value=42)
with self.assertRaises(TypeError):
object_scope.resolve_symbol_to_object(node, scope)


if __name__ == "__main__":
unittest.main()
Loading