Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .binder/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ dependencies:
- python >=3.11, <3.14
- networkx =3.6.1
- pydantic =2.12.5
- pyiron_snippets =1.1.0
- pyiron_snippets =1.2.0
2 changes: 1 addition & 1 deletion .ci_support/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ dependencies:
- python >=3.11, <3.14
- networkx =3.6.1
- pydantic =2.12.5
- pyiron_snippets =1.1.0
- pyiron_snippets =1.2.0
2 changes: 1 addition & 1 deletion .ci_support/lower-bounds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ dependencies:
- python =3.11
- networkx =3.4.2
- pydantic =2.12.0
- pyiron_snippets =1.1.0
- pyiron_snippets =1.2.0
2 changes: 1 addition & 1 deletion docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ dependencies:
- python >=3.11, <3.14
- networkx =3.6.1
- pydantic =2.12.5
- pyiron_snippets =1.1.0
- pyiron_snippets =1.2.0
24 changes: 19 additions & 5 deletions flowrep/models/parsers/atomic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,30 @@ def get_labeled_recipe(
ast_call: ast.Call,
existing_names: Iterable[str],
scope: object_scope.ScopeProxy,
info_factory: versions.VersionInfoFactory,
) -> helper_models.LabeledNode:
child_call = cast(
FunctionType, object_scope.resolve_symbol_to_object(ast_call.func, scope)
)
# Since it is the .func attribute of an ast.Call,
# the retrieved object had better be a function
child_recipe = (
child_call.flowrep_recipe
if hasattr(child_call, "flowrep_recipe")
else parse_atomic(child_call)
)
if hasattr(child_call, "flowrep_recipe"):
child_recipe = child_call.flowrep_recipe
if hasattr(child_recipe, "source") and isinstance(
child_recipe.source, versions.VersionInfo
):
child_recipe.source.validate_constraints(
forbid_main=info_factory.forbid_main,
forbid_locals=info_factory.forbid_locals,
require_version=info_factory.require_version,
)
else:
child_recipe = parse_atomic(
child_call,
version_scraping=info_factory.version_scraping,
forbid_main=info_factory.forbid_main,
forbid_locals=info_factory.forbid_locals,
require_version=info_factory.require_version,
)
child_name = label_helpers.unique_suffix(child_call.__name__, existing_names)
return helper_models.LabeledNode(label=child_name, node=child_recipe)
21 changes: 10 additions & 11 deletions flowrep/models/parsers/case_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ast
import dataclasses

from pyiron_snippets import versions

from flowrep.models import edge_models, subgraph_validation
from flowrep.models.nodes import helper_models
from flowrep.models.parsers import (
Expand All @@ -18,6 +20,7 @@ def parse_case(
test: ast.expr,
scope: object_scope.ScopeProxy,
symbol_map: symbol_scope.SymbolScope,
info_factory: versions.VersionInfoFactory,
label: str,
) -> tuple[helper_models.LabeledNode, edge_models.InputEdges]:
"""
Expand All @@ -31,14 +34,14 @@ def parse_case(
"Test conditions must be a function call, but got " f"{type(test).__name__}"
)

condition = atomic_parser.get_labeled_recipe(test, set(), scope)
condition = atomic_parser.get_labeled_recipe(test, set(), scope, info_factory)
if len(condition.node.outputs) != 1:
raise ValueError(
f"If/elif condition must return exactly one value (and it had better be "
f"truthy), but got {condition.node.outputs}"
)

scope_copy = symbol_map.fork_scope()
scope_copy = symbol_map.fork()
parser_helpers.consume_call_arguments(scope_copy, test, condition)
return _relabel_node_data(condition, scope_copy.input_edges, label)

Expand Down Expand Up @@ -70,18 +73,14 @@ def to_labeled_node(self) -> helper_models.LabeledNode:


def walk_branch(
label: str,
stmts: list[ast.stmt],
symbol_map: symbol_scope.SymbolScope,
scope: object_scope.ScopeProxy,
walker_factory: parser_protocol.WalkerFactory,
walker: parser_protocol.BodyWalker, label: str, stmts: list[ast.stmt]
) -> WalkedBranch:
fork = symbol_map.fork_scope()
w = walker_factory(scope, fork)
w.walk(stmts)
fork = walker.symbol_map.fork()
branch_walker = walker.fork(new_symbol_map=fork)
branch_walker.walk(stmts)
assigned = fork.assigned_symbols
fork.produce_symbols(assigned)
return WalkedBranch(label, w, assigned)
return WalkedBranch(label, branch_walker, assigned)


def wire_inputs(
Expand Down
23 changes: 9 additions & 14 deletions flowrep/models/parsers/for_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from flowrep.models import edge_models
from flowrep.models.nodes import for_model, helper_models
from flowrep.models.parsers import object_scope, parser_protocol, symbol_scope
from flowrep.models.parsers import parser_protocol, symbol_scope

FOR_BODY_LABEL: str = "body"

Expand All @@ -24,41 +24,36 @@ class _IterationAxis(NamedTuple):


def parse_for_node(
tree: ast.For,
scope: object_scope.ScopeProxy,
symbol_map: symbol_scope.SymbolScope,
walker_factory: parser_protocol.WalkerFactory,
walker: parser_protocol.BodyWalker, tree: ast.For
) -> for_model.ForNode:
"""
Walk a for-loop.

Args:
walker: A walker to fork and use for collecting state inside the tree.
tree: The top-level ``ast.For`` node (may contain immediately
nested for-headers that declare additional iteration axes).
scope: Object-level scope for resolving callable references.
symbol_map: The enclosing :class:`SymbolScope` (used for forking).
walker_factory: Callable that creates a :class:`BodyWalker` from a
:class:`SymbolScope`. Avoids a circular import with
``workflow_parser.WorkflowParser``.
"""
# Parse the iteration header — pure AST, no parser state needed
nested_iters, zipped_iters, body_tree = _parse_for_iterations(tree)
all_iters = nested_iters + zipped_iters

# When we fork the scope here, we replace iterated-over symbols with iteration
# variables, all as InputSources from the body's perspective
body_symbol_map = symbol_map.fork_scope(
body_symbol_map = walker.symbol_map.fork(
{src: var for var, src in all_iters},
available_accumulators=symbol_map.declared_accumulators.copy(),
available_accumulators=walker.symbol_map.declared_accumulators.copy(),
)

body_walker = walker_factory(scope, body_symbol_map)
body_walker = walker.fork(new_symbol_map=body_symbol_map)
body_walker.walk(body_tree.body)
consumed = body_walker.symbol_map.consumed_accumulators

_validate_some_output_exists(consumed)
_validate_no_unused_iterators(all_iters, body_walker, consumed)
_validate_no_leaked_reassignments(all_iters, body_walker, consumed, symbol_map)
_validate_no_leaked_reassignments(
all_iters, body_walker, consumed, walker.symbol_map
)

nested_ports = [var for var, _ in nested_iters]
zipped_ports = [var for var, _ in zipped_iters]
Expand Down
34 changes: 10 additions & 24 deletions flowrep/models/parsers/if_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

from flowrep.models import edge_models
from flowrep.models.nodes import helper_models, if_model
from flowrep.models.parsers import (
case_helpers,
object_scope,
parser_protocol,
symbol_scope,
)
from flowrep.models.parsers import case_helpers, parser_protocol

IF_CONDITION_LABEL_PREFIX: str = "condition"
IF_BODY_LABEL_PREFIX: str = "body"
Expand All @@ -26,22 +21,13 @@ class _CaseComponents:
body: case_helpers.WalkedBranch


def parse_if_node(
tree: ast.If,
scope: object_scope.ScopeProxy,
symbol_map: symbol_scope.SymbolScope,
walker_factory: parser_protocol.WalkerFactory,
):
def parse_if_node(walker: parser_protocol.BodyWalker, tree: ast.If) -> if_model.IfNode:
"""
Walk an if/elif/else chain.

Args:
walker: A walker to fork and use for collecting state inside the tree.
tree: The top-level ``ast.If`` node.
scope: Object-level scope for resolving callable references.
symbol_map: The enclosing :class:`SymbolScope` (used for forking).
walker_factory: Callable that creates a :class:`BodyWalker` from a
:class:`SymbolScope`. Avoids a circular import with
``workflow_parser.WorkflowParser``.
"""

cases: list[_CaseComponents] = []
Expand All @@ -55,11 +41,13 @@ def parse_if_node(
body_label = f"{IF_BODY_LABEL_PREFIX}_{idx}"

labeled_cond, cond_inputs = case_helpers.parse_case(
test_expr, scope, symbol_map, cond_label
)
body = case_helpers.walk_branch(
body_label, body_stmts, symbol_map, scope, walker_factory
test_expr,
walker.scope,
walker.symbol_map,
walker.info_factory,
cond_label,
)
body = case_helpers.walk_branch(walker, body_label, body_stmts)
cases.append(
_CaseComponents(
condition=labeled_cond,
Expand All @@ -70,9 +58,7 @@ def parse_if_node(

# --- process else case (if present) ---
if else_stmts is not None:
else_branch = case_helpers.walk_branch(
IF_ELSE_LABEL, else_stmts, symbol_map, scope, walker_factory
)
else_branch = case_helpers.walk_branch(walker, IF_ELSE_LABEL, else_stmts)

# --- wire edges ---
body_branches = [cc.body for cc in cases]
Expand Down
14 changes: 7 additions & 7 deletions flowrep/models/parsers/parser_protocol.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
from __future__ import annotations

import ast
from collections.abc import Callable
from typing import Protocol, runtime_checkable

from pyiron_snippets import versions

from flowrep.models import edge_models
from flowrep.models.nodes import union, workflow_model
from flowrep.models.parsers import object_scope, symbol_scope

WalkerFactory = Callable[
[object_scope.ScopeProxy, symbol_scope.SymbolScope], "BodyWalker"
]


@runtime_checkable
class BodyWalker(Protocol):
"""What control flow parsers need to walk a sub-body."""

scope: object_scope.ScopeProxy
symbol_map: symbol_scope.SymbolScope
info_factory: versions.VersionInfoFactory
nodes: union.Nodes

@property
Expand All @@ -36,8 +34,10 @@ def output_edges(self) -> edge_models.OutputEdges: ...
@property
def outputs(self) -> list[str]: ...

def visit(self, stmt: ast.AST) -> None: ...
def build_model(self) -> workflow_model.WorkflowNode: ...

def fork(self, *, new_symbol_map: symbol_scope.SymbolScope) -> BodyWalker: ...

def walk(self, statements: list[ast.stmt]) -> None: ...

def build_model(self) -> workflow_model.WorkflowNode: ...
def visit(self, stmt: ast.AST) -> None: ...
2 changes: 1 addition & 1 deletion flowrep/models/parsers/symbol_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def use_accumulator(self, accumulator_symbol: str, appended_symbol: str) -> None
self.consumed_accumulators[accumulator_symbol] = appended_symbol

# --- Forking for child scopes ---
def fork_scope(
def fork(
self,
symbol_remap: dict[str, str] | None = None,
available_accumulators: set[str] | None = None,
Expand Down
36 changes: 6 additions & 30 deletions flowrep/models/parsers/try_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,21 @@
from pyiron_snippets import versions

from flowrep.models.nodes import helper_models, try_model
from flowrep.models.parsers import (
case_helpers,
object_scope,
parser_protocol,
symbol_scope,
)
from flowrep.models.parsers import case_helpers, object_scope, parser_protocol

TRY_BODY_LABEL: str = "try_body"
EXCEPT_BODY_LABEL_PREFIX: str = "except_body"


def parse_try_node(
tree: ast.Try,
scope: object_scope.ScopeProxy,
symbol_map: symbol_scope.SymbolScope,
walker_factory: parser_protocol.WalkerFactory,
walker: parser_protocol.BodyWalker, tree: ast.Try
) -> try_model.TryNode:
"""
Walk a try/except block.

Args:
walker: A walker to fork and use for collecting state inside the tree.
tree: The ``ast.Try`` node.
scope: Object-level scope for resolving callable references.
symbol_map: The enclosing :class:`SymbolScope` (used for forking).
walker_factory: Callable that creates a :class:`BodyWalker` from a
:class:`SymbolScope`. Avoids a circular import with
``workflow_parser.WorkflowParser``.
"""
# 0. Fail early for unsupported syntax
if tree.orelse:
Expand All @@ -45,29 +33,17 @@ def parse_try_node(
)

# 1. Parse the try body
try_branch = case_helpers.walk_branch(
TRY_BODY_LABEL,
tree.body,
symbol_map,
scope,
walker_factory,
)
try_branch = case_helpers.walk_branch(walker, TRY_BODY_LABEL, tree.body)

# 2. Parse each except handler
exception_groups: list[list[str]] = []
except_branches: list[case_helpers.WalkedBranch] = []
for idx, handler in enumerate(tree.handlers):
body_label = f"{EXCEPT_BODY_LABEL_PREFIX}_{idx}"

exception_groups.append(_parse_exception_types(handler, scope))
exception_groups.append(_parse_exception_types(handler, walker.scope))

exception_branch = case_helpers.walk_branch(
body_label,
handler.body,
symbol_map,
scope,
walker_factory,
)
exception_branch = case_helpers.walk_branch(walker, body_label, handler.body)
except_branches.append(exception_branch)

# 3. Wire edges
Expand Down
Loading
Loading