Restarting process #458

Closed · wants to merge 6 commits
82 changes: 74 additions & 8 deletions pyiron_workflow/node.py
@@ -7,12 +7,12 @@

from __future__ import annotations

-from abc import ABC
+from abc import ABC, abstractmethod
from concurrent.futures import Future
from importlib import import_module
from inspect import getsource
from typing import Any, Literal, Optional, TYPE_CHECKING

import cloudpickle
from pyiron_snippets.colors import SeabornColors
from pyiron_snippets.dotdict import DotDict

@@ -152,8 +152,8 @@ class Node(

This is an abstract class.
Children *must* define how :attr:`inputs` and :attr:`outputs` are constructed,
-    what will happen :meth:`on_run`, the :attr:`run_args` that will get passed to
-    :meth:`on_run`, and how to :meth:`process_run_result` once :meth:`on_run` finishes.
+    what will happen :meth:`_on_run`, the :attr:`_run_args` that will get passed to
+    :meth:`_on_run`, and how to :meth:`process_run_result` once :meth:`_on_run` finishes.
They may optionally add additional signal channels to the signals IO.

Attributes:
@@ -178,8 +178,8 @@ class Node(
graph_path (str): The file-path-like path of node labels from the parent-most
node down to this node.
graph_root (Node): The parent-most node in this graph.
-        run_args (dict): **Abstract** the argmuments to use for actually running the
-            node. Must be specified in child classes.
+        run_args (dict): What to pass to the `on_run` method when :meth:`run` is called.
+            Leans on the abstract :attr:`_run_args` defined in child classes.
running (bool): Whether the node has called :meth:`run` and has not yet
received output from this call. (Default is False.)
checkpoint (Literal["pickle"] | StorageInterface | None): Whether to trigger a
@@ -188,6 +188,9 @@
autoload (Literal["pickle"] | StorageInterface | None): Whether to check
for a matching saved node and what storage back end to use to do so (no
auto-loading if the back end is `None`.)
+        serialize_result (bool): Cloudpickle the output of running the node; this is
+            useful if the run is happening in a parallel process and the parent process
+            may be killed before it is finished.
signals (pyiron_workflow.io.Signals): A container for input and output
signals, which are channels for controlling execution flow. By default, has
a :attr:`signals.inputs.run` channel which has a callback to the
@@ -214,8 +217,8 @@
its internal structure.
execute: An alias for :meth:`run`, but with flags to run right here, right now,
and with the input it currently has.
-        on_run: **Abstract.** Do the thing. What thing must be specified by child
-            classes.
+        on_run: What the node does on running; leans on the abstract `_on_run` method
+            defined by children.
pull: An alias for :meth:`run` that runs everything upstream, then runs this
node (but doesn't fire off the `ran` signal, so nothing happens farther
downstream). "Upstream" may optionally break out of the local scope to run
@@ -277,6 +280,9 @@ def __init__(
checkpoint (Literal["pickle"] | StorageInterface | None): The storage
back end to use for saving the overall graph at the end of this node's
run. (Default is None, don't do checkpoint saves.)
+            serialize_result (bool): Cloudpickle the output of running the node; this
+                is useful if the run is happening in a parallel process and the parent
+                process may be killed before it is finished. (Default is False.)
**kwargs: Interpreted as node input data, with keys corresponding to
channel labels.
"""
@@ -285,6 +291,10 @@
parent=parent,
)
self.checkpoint = checkpoint
+        self.serialize_result = False
+        self._do_clean: bool = True  # Power-user override for cleaning up serialized
+        # results (or not).

self._cached_inputs = None
self._user_data = {} # A place for power-users to bypass node-injection

@@ -368,6 +378,33 @@ def _readiness_error_message(self) -> str:
f" conform to type hints.\n" + self.readiness_report
)

+    def on_run(self, *args, **kwargs) -> Any:
+        save_result: bool = args[0]
+        save_file: Path = args[1]
+        args = args[2:]
+        result = self._on_run(*args, **kwargs)
+        if save_result:
+            save_file.parent.mkdir(parents=True, exist_ok=True)
+            save_file.touch(exist_ok=False)
+            with save_file.open(mode="wb") as f:
+                cloudpickle.dump(result, f)
+        return result
+
+    @abstractmethod
+    def _on_run(self, *args, **kwargs) -> Any:
+        pass
+
+    @property
+    def run_args(self) -> tuple[tuple, dict]:
+        args, kwargs = self._run_args
+        args = (self.serialize_result, self._temporary_results_file) + args
+        return args, kwargs
+
+    @property
+    @abstractmethod
+    def _run_args(self, *args, **kwargs) -> Any:
+        pass

def run(
self,
*args,
@@ -426,6 +463,10 @@ def run(
Kwargs updating input channel values happens _first_ and will get
overwritten by any subsequent graph-based data manipulation.
"""
+        if self.running and self._temporary_results_file.is_file():
+            # Bypass running and just load the results
+            return self._load_and_finish(emit_ran_signal)

self.set_input_values(*args, **kwargs)

if run_data_tree:
@@ -579,6 +620,8 @@ def _finish_run(self, run_output: tuple | Future) -> Any | tuple:
finally:
if self.checkpoint is not None:
self.save_checkpoint(self.checkpoint)
+            if self._temporary_results_file.is_file():
+                self._clean_temporary_results()

def _finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple:
processed_output = self._finish_run(run_output)
@@ -596,6 +639,29 @@ def _finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple:
"""
)

+    def _load_and_finish(self, emit_ran_signal: bool):
+        _finished_callback = (
+            self._finish_run_and_emit_ran if emit_ran_signal else self._finish_run
+        )
+        with self._temporary_results_file.open("rb") as f:
+            result = cloudpickle.load(f)
+        return _finished_callback(result)
+
+    @property
+    def _temporary_results_file(self) -> Path:
+        return self.as_path().joinpath("TMP_RESULT.CPKL")
+
+    def _clean_temporary_results(self):
+        if self._do_clean:
+            self._temporary_results_file.unlink()
+
+            # Recursively remove empty directories
+            root_directory = self.semantic_root.as_path().parent
+            for parent in self._temporary_results_file.parents:
+                if parent == root_directory or any(parent.iterdir()):
+                    break
+                parent.rmdir()

@property
def emitting_channels(self) -> tuple[OutputSignal]:
return (self.signals.output.ran,)
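The node.py changes above amount to a template-method split: the public `on_run` now wraps an abstract `_on_run`, and when `serialize_result` is set it cloudpickles the return value to a temporary file that `_load_and_finish` can recover after a restart. A minimal, library-independent sketch of that save-then-recover round trip (the helper names and file layout are illustrative, not pyiron_workflow API):

```python
from pathlib import Path

import cloudpickle


def run_and_serialize(save_file: Path, fnc, *args, **kwargs):
    # Mirrors Node.on_run: compute, then dump the result for later recovery
    result = fnc(*args, **kwargs)
    save_file.parent.mkdir(parents=True, exist_ok=True)
    with save_file.open("wb") as f:
        cloudpickle.dump(result, f)
    return result


def load_instead_of_rerunning(save_file: Path):
    # Mirrors Node._load_and_finish: a restarted parent skips re-execution
    with save_file.open("rb") as f:
        return cloudpickle.load(f)


tmp = Path("some_node") / "TMP_RESULT.CPKL"
assert run_and_serialize(tmp, lambda x: x + 1, 42) == 43
assert load_instead_of_rerunning(tmp) == 43  # e.g. after a process restart
tmp.unlink()  # analogue of Node._clean_temporary_results
tmp.parent.rmdir()
```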
6 changes: 3 additions & 3 deletions pyiron_workflow/nodes/composite.py
@@ -137,11 +137,11 @@ def deactivate_strict_hints(self):
for node in self:
node.deactivate_strict_hints()

-    def on_run(self):
+    def _on_run(self):
# Reset provenance and run status trackers
self.provenance_by_execution = []
self.provenance_by_completion = []
-        self.running_children = []
+        self.running_children = [n.label for n in self if n.running]
self.signal_queue = []

for node in self.starting_nodes:
@@ -200,7 +200,7 @@ def register_child_emitting(self, child: Node) -> None:
self.signal_queue.append((firing, receiving))

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
return (), {}

def process_run_result(self, run_output):
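The substantive composite change is the `running_children` seed: instead of always starting from an empty list, `_on_run` now records which children are still flagged as running, so a restarted graph knows whose serialized results to pick up rather than re-execute. A toy illustration of the difference (the `Child` stand-in is hypothetical, not the library's node class):

```python
from dataclasses import dataclass


@dataclass
class Child:
    label: str
    running: bool


children = [Child("a", False), Child("b", True), Child("c", False)]

# Old behaviour: self.running_children = [] wiped the tracker on every run.
# New behaviour: children still flagged as running survive into the restart:
running_children = [c.label for c in children if c.running]
assert running_children == ["b"]
```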
4 changes: 2 additions & 2 deletions pyiron_workflow/nodes/for_loop.py
@@ -236,9 +236,9 @@ def _setup_node(self) -> None:
self.starting_nodes = input_nodes
self._input_node_labels = tuple(n.label for n in input_nodes)

-    def on_run(self):
+    def _on_run(self):
self._build_body()
-        return super().on_run()
+        return super()._on_run()

def _build_body(self):
"""
4 changes: 2 additions & 2 deletions pyiron_workflow/nodes/function.py
@@ -313,11 +313,11 @@ def _build_outputs_preview(cls) -> dict[str, Any]:
return preview if len(preview) > 0 else {"None": type(None)}
# If clause facilitates functions with no return value

-    def on_run(self, **kwargs):
+    def _on_run(self, **kwargs):
return self.node_function(**kwargs)

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
kwargs = self.inputs.to_value_dict()
return (), kwargs

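The function-node pair shows the contract between the two abstract pieces most clearly: `_run_args` packs the current input channel values into kwargs, and `_on_run` forwards them to the wrapped function. A stripped-down sketch of that contract (`MiniFunctionNode` is illustrative, not the library's implementation):

```python
class MiniFunctionNode:
    def __init__(self, node_function, **inputs):
        self.node_function = node_function
        self.inputs = inputs  # stand-in for the real input channels

    def _on_run(self, **kwargs):
        return self.node_function(**kwargs)

    @property
    def _run_args(self) -> tuple[tuple, dict]:
        # No positional args; every input travels as a keyword argument
        return (), dict(self.inputs)

    def run(self):
        # The real base class interposes Node.on_run here for serialization
        args, kwargs = self._run_args
        return self._on_run(*args, **kwargs)


node = MiniFunctionNode(lambda x, y: x + y, x=1, y=2)
assert node.run() == 3
```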
26 changes: 13 additions & 13 deletions pyiron_workflow/nodes/transform.py
@@ -35,15 +35,15 @@ class FromManyInputs(Transformer, ABC):
_output_type_hint: ClassVar[Any] = None

# _build_inputs_preview required from parent class
-    # Inputs convert to `run_args` as a value dictionary
-    # This must be commensurate with the internal expectations of on_run
+    # Inputs convert to `_run_args` as a value dictionary
+    # This must be commensurate with the internal expectations of _on_run

@abstractmethod
-    def on_run(self, **inputs_to_value_dict) -> Any:
+    def _on_run(self, **inputs_to_value_dict) -> Any:
"""Must take inputs kwargs"""

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
return (), self.inputs.to_value_dict()

@classmethod
@@ -64,11 +64,11 @@ class ToManyOutputs(Transformer, ABC):
# Must be commensurate with the dictionary returned by transform_to_output

@abstractmethod
-    def on_run(self, input_object) -> callable[..., Any | tuple]:
+    def _on_run(self, input_object) -> callable[..., Any | tuple]:
"""Must take the single object to be transformed"""

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
return (self.inputs[self._input_name].value,), {}

@classmethod
@@ -89,7 +89,7 @@ class InputsToList(_HasLength, FromManyInputs, ABC):
_output_name: ClassVar[str] = "list"
_output_type_hint: ClassVar[Any] = list

-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
return list(inputs_to_value_dict.values())

@classmethod
@@ -101,7 +101,7 @@ class ListToOutputs(_HasLength, ToManyOutputs, ABC):
_input_name: ClassVar[str] = "list"
_input_type_hint: ClassVar[Any] = list

-    def on_run(self, input_object: list):
+    def _on_run(self, input_object: list):
return {f"item_{i}": v for i, v in enumerate(input_object)}

@classmethod
@@ -184,7 +184,7 @@ class InputsToDict(FromManyInputs, ABC):
list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]]
]

-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
return inputs_to_value_dict

@classmethod
@@ -284,7 +284,7 @@ class InputsToDataframe(_HasLength, FromManyInputs, ABC):
_output_name: ClassVar[str] = "df"
_output_type_hint: ClassVar[Any] = DataFrame

-    def on_run(self, *rows: dict[str, Any]) -> Any:
+    def _on_run(self, *rows: dict[str, Any]) -> Any:
df_dict = {}
for i, row in enumerate(rows):
for key, value in row.items():
@@ -295,7 +295,7 @@ def on_run(self, *rows: dict[str, Any]) -> Any:
return DataFrame(df_dict)

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
return tuple(self.inputs.to_value_dict().values()), {}

@classmethod
@@ -363,11 +363,11 @@ def _setup_node(self) -> None:
):
self.inputs[name] = self._dataclass_fields[name].default_factory()

-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
return self.dataclass(**inputs_to_value_dict)

@property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
return (), self.inputs.to_value_dict()

@classmethod
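Of the renamed transformers, `InputsToDataframe._on_run` has the only non-trivial body: it takes one dictionary per row and pivots them into column lists before constructing the frame. The middle of that loop is elided by the diff, but a plausible standalone reconstruction looks like this (assuming pandas, which the module already imports as `DataFrame`):

```python
from pandas import DataFrame


def rows_to_dataframe(*rows: dict) -> DataFrame:
    df_dict: dict[str, list] = {}
    for i, row in enumerate(rows):
        for key, value in row.items():
            if i == 0:
                df_dict[key] = [value]  # first row opens each column
            else:
                df_dict[key].append(value)  # later rows extend it
    return DataFrame(df_dict)


df = rows_to_dataframe({"x": 1, "y": 2}, {"x": 3, "y": 4})
assert list(df["x"]) == [1, 3] and list(df["y"]) == [2, 4]
```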
32 changes: 32 additions & 0 deletions tests/unit/nodes/test_composite.py
@@ -603,6 +603,38 @@ def test_with_executor(self):
"retain its executor"
)

+    def test_result_serialization(self):
+        """
+        This is actually only a useful feature if you have an executor which will
+        continue the process _after_ the parent python process has been shut down
+        (e.g., you sent the run code off to a slurm queue using `executorlib`), but
+        we'll ensure that the plumbing works here by faking things a bit.
+        """
+        self.comp.use_cache = False
+
+        self.comp.child = Composite.create.function_node(plus_one, x=42)
+        self.comp.starting_nodes = [self.comp.child]
+
+        self.comp.child.serialize_result = True
+        self.comp.child.use_cache = False
+        self.comp.child._do_clean = False
+
+        out = self.comp.run()
+        self.assertTrue(self.comp.child._temporary_results_file.is_file())
+        self.assertEqual(self.comp.child.outputs.y.value, 42 + 1)
+
+        self.comp.child.running = True  # Fake it
+        self.comp.child._do_clean = True  # Clean up this time
+        self.comp.run()
+
+        self.assertFalse(self.comp.child._temporary_results_file.is_file())
+        self.assertEqual(self.comp.child.outputs.y.value, 42 + 1)
+        self.assertFalse(
+            self.comp.as_path().is_dir(),
+            msg="Actually, we expect cleanup to have removed empty directories up to "
+            "and including the semantic root's own directory"
+        )


if __name__ == '__main__':
unittest.main()
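Condensed from the composite test above, the intended user workflow is: opt in with `serialize_result`, leave the node flagged as `running` to simulate a run that outlived its parent process, and let the second call recover the pickled output. This is a usage sketch reusing the test's `plus_one` helper and `Composite` import, not additional tested code:

```python
n = Composite.create.function_node(plus_one, x=42)
n.serialize_result = True  # dump the output to a temporary results file
n.use_cache = False
n._do_clean = False        # keep the file so a "restart" can find it

out = n()                  # runs plus_one and serializes the result

n.running = True           # fake a node whose run outlived the parent process
n._do_clean = True         # clean up after recovery this time
reloaded = n()             # loads the serialized result instead of re-running
assert out == reloaded
```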
34 changes: 32 additions & 2 deletions tests/unit/test_node.py
@@ -30,11 +30,11 @@ def inputs(self) -> Inputs:
def outputs(self) -> OutputsWithInjection:
return self._outputs

-    def on_run(self, *args, **kwargs):
+    def _on_run(self, *args, **kwargs):
return add_one(*args)

@property
-    def run_args(self) -> dict:
+    def _run_args(self) -> dict:
return (self.inputs.x.value,), {}

def process_run_result(self, run_output):
@@ -511,6 +511,36 @@ def test_checkpoint(self):
finally:
saves.delete_storage(backend) # Clean up

+    def test_result_serialization(self):
+        """
+        This is actually only a useful feature if you have an executor which will
+        continue the process _after_ the parent python process has been shut down
+        (e.g., you sent the run code off to a slurm queue using `executorlib`), but
+        we'll ensure that the plumbing works here by faking things a bit.
+        """
+        n = ANode(label="test", x=42)
+        n.serialize_result = True
+        n.use_cache = False
+        n._do_clean = False  # Power-user override to prevent the serialization from
+        # being removed
+        out = n()
+        self.assertTrue(
+            n._temporary_results_file.is_file(),
+            msg="Sanity check that we've saved the output"
+        )
+        # Now fake it
+        n.running = True
+        n._do_clean = True  # This time clean up after yourself
+        reloaded = n()
+        self.assertEqual(out, reloaded)
+        self.assertFalse(n.running)
+        self.assertFalse(n._temporary_results_file.is_file())
+        self.assertFalse(
+            n.as_path().is_dir(),
+            msg="Actually, we expect cleanup to have removed empty directories up to "
+            "and including the node's own directory"
+        )


if __name__ == '__main__':
unittest.main()
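Both tests end by asserting that cleanup pruned the node's now-empty directories. That behaviour comes from `_clean_temporary_results` walking up from the results file and removing parents until it meets a non-empty directory or the storage root. A standalone sketch of that walk (paths are illustrative):

```python
from pathlib import Path


def remove_file_and_empty_parents(file: Path, root: Path) -> None:
    # Mirrors Node._clean_temporary_results: unlink the results file, then
    # prune now-empty directories up to (but not including) the root
    file.unlink()
    for parent in file.parents:
        if parent == root or any(parent.iterdir()):
            break
        parent.rmdir()


root = Path("storage_root")
f = root / "workflow" / "node" / "TMP_RESULT.CPKL"
f.parent.mkdir(parents=True)
f.touch()
remove_file_and_empty_parents(f, root)
assert not (root / "workflow").exists() and root.exists()
root.rmdir()  # tidy up the example itself
```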