
Commit c57d387

[patch] Persistent process (#476)
* Catch already-running children in Composite

* Add temporary results infrastructure

* Allow run results to be (de)serialized at run-time

  By adding a new flag, having `Node.on_run` directly handle the serialization of results (making `_on_run` and `_run_args` new abstract methods that have the behaviour of the old public methods), and deserializing temporary results instead of running when a node is already running and such results exist

* Format black

* Remove redundant arg

* Save a checkpoint when running with result serialization

  So that the graph gets saved with the serializer in a running state

* Refactor: extract method

* Catch the case that you re-run a node still waiting for its serialized result

* Add a helper method for resuming composite runs from broken processes

* Revert resume_from_broken_process

* Make cleaning revert to false

  It's a bit messier for the filesystem, but for now let's default to keeping the data around

* Make the flag private

  I want this functionality in, but I'm not at all happy with the UI, and don't totally trust edge cases (e.g. input changing under our feet), so let's put it in private for now in anticipation of changes

* Remove the checkpoint save

  We'll get a recovery file when we close the parent process anyhow

* Extend doc

* Don't delete checkpoint as it's no longer written

* Refactor: slide

* Update HPC example notebook

  With a real living example of `Node._serialize_result` working

---------

Co-authored-by: pyiron-runner <pyiron@mpie.de>
1 parent e4a066c commit c57d387

File tree: 9 files changed, +1421 −54 lines changed

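For orientation, the following is a minimal sketch of the flow this patch enables, mirroring the `test_result_serialization` unit test added below. The function `plus_one`, the label "persistent_demo", and the use of the `Workflow` class with its `create.function_node` factory are illustrative assumptions rather than part of this diff:

    from pyiron_workflow import Workflow

    def plus_one(x: int = 0) -> int:
        return x + 1

    wf = Workflow("persistent_demo")
    wf.child = Workflow.create.function_node(plus_one, x=42)

    wf.child._serialize_result = True  # cloudpickle the run result to disk on completion
    wf.child._do_clean = False  # keep the temporary file around (the new default)

    wf.run()
    assert wf.child._temporary_result_file.is_file()  # <node directory>/run_result.tmp

    # Had the parent process died mid-run and the graph been recovered with the
    # child still flagged as `running`, calling `wf.run()` again would load this
    # file instead of re-executing the child.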

notebooks/deepdive.ipynb

+1 −1
@@ -3391,7 +3391,7 @@
     " File \"/Users/huber/anaconda3/envs/pyiron_311/lib/python3.11/concurrent/futures/process.py\", line 261, in _process_worker\n",
     " r = call_item.fn(*call_item.args, **call_item.kwargs)\n",
     " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
-    " File \"/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/nodes/function.py\", line 317, in on_run\n",
+    " File \"/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/nodes/function.py\", line 317, in _on_run\n",
     " return self.node_function(**kwargs)\n",
     " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
     " File \"/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/nodes/standard.py\", line 518, in Add\n",

notebooks/hpc_example.ipynb

+1,238 −23
Large diffs are not rendered by default.

pyiron_workflow/node.py

+86 −5
@@ -7,11 +7,12 @@
 
 from __future__ import annotations
 
-from abc import ABC
+from abc import ABC, abstractmethod
 from concurrent.futures import Future
 from importlib import import_module
 from typing import Any, Literal, Optional, TYPE_CHECKING
 
+import cloudpickle
 from pyiron_snippets.colors import SeabornColors
 from pyiron_snippets.dotdict import DotDict
 
@@ -152,8 +153,8 @@ class Node(
 
     This is an abstract class.
     Children *must* define how :attr:`inputs` and :attr:`outputs` are constructed,
-    what will happen :meth:`on_run`, the :attr:`run_args` that will get passed to
-    :meth:`on_run`, and how to :meth:`process_run_result` once :meth:`on_run` finishes.
+    what will happen :meth:`_on_run`, the :attr:`run_args` that will get passed to
+    :meth:`_on_run`, and how to :meth:`process_run_result` once :meth:`_on_run` finishes.
     They may optionally add additional signal channels to the signals IO.
 
     Attributes:
@@ -192,6 +193,9 @@ class Node(
         autoload (Literal["pickle"] | StorageInterface | None): Whether to check
             for a matching saved node and what storage back end to use to do so (no
             auto-loading if the back end is `None`.)
+        _serialize_result (bool): (IN DEVELOPMENT) Cloudpickle the output of running
+            the node; this is useful if the run is happening in a parallel process and
+            the parent process may be killed before it is finished. (Default is False.)
         signals (pyiron_workflow.io.Signals): A container for input and output
             signals, which are channels for controlling execution flow. By default, has
             a :attr:`signals.inputs.run` channel which has a callback to the
@@ -218,7 +222,7 @@ class Node(
             its internal structure.
         execute: An alias for :meth:`run`, but with flags to run right here, right now,
             and with the input it currently has.
-        on_run: **Abstract.** Do the thing. What thing must be specified by child
+        _on_run: **Abstract.** Do the thing. What thing must be specified by child
             classes.
         pull: An alias for :meth:`run` that runs everything upstream, then runs this
             node (but doesn't fire off the `ran` signal, so nothing happens farther
@@ -227,7 +231,7 @@ class Node(
             object is encountered).
         replace_with: If the node belongs to a parent, attempts to replace itself in
             that parent with a new provided node.
-        run: Run the node function from :meth:`on_run`. Handles status automatically.
+        run: Run the node function from :meth:`_on_run`. Handles status automatically.
             Various execution options are available as boolean flags.
         set_input_values: Allows input channels' values to be updated without any
             running.
@@ -290,6 +294,10 @@ def __init__(
         )
         self.checkpoint = checkpoint
         self.recovery: Literal["pickle"] | StorageInterface | None = "pickle"
+        self._serialize_result = False  # Advertised, but private to indicate
+        # under-development status -- API may change to be more user-friendly
+        self._do_clean: bool = False  # Power-user override for cleaning up temporary
+        # serialized results and empty directories (or not).
         self._cached_inputs = None
         self._user_data = {}  # A place for power-users to bypass node-injection
 
@@ -373,6 +381,29 @@ def _readiness_error_message(self) -> str:
             f" conform to type hints.\n" + self.readiness_report
         )
 
+    def on_run(self, *args, **kwargs) -> Any:
+        save_result: bool = args[0]
+        args = args[1:]
+        result = self._on_run(*args, **kwargs)
+        if save_result:
+            self._temporary_result_pickle(result)
+        return result
+
+    @abstractmethod
+    def _on_run(self, *args, **kwargs) -> Any:
+        pass
+
+    @property
+    def run_args(self) -> tuple[tuple, dict]:
+        args, kwargs = self._run_args
+        args = (self._serialize_result,) + args
+        return args, kwargs
+
+    @property
+    @abstractmethod
+    def _run_args(self, *args, **kwargs) -> Any:
+        pass
+
     def run(
         self,
         *args,
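Under the new contract, concrete node classes implement the private hooks `_on_run` and `_run_args`, while the public `on_run` and `run_args` transparently inject and consume the `save_result` flag. A hypothetical minimal subclass for illustration (eliding the IO construction and `process_run_result` machinery a real node also needs):

    class Square(Node):  # sketch only; real nodes come from the usual factories
        def _on_run(self, x):
            return x**2

        @property
        def _run_args(self) -> tuple[tuple, dict]:
            return (self.inputs.x.value,), {}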
@@ -431,6 +462,22 @@ def run(
             Kwargs updating input channel values happens _first_ and will get
             overwritten by any subsequent graph-based data manipulation.
         """
+        if self.running and self._serialize_result:
+            if self._temporary_result_file.is_file():
+                return self._finish_run(
+                    self._temporary_result_unpickle(),
+                    raise_run_exceptions=raise_run_exceptions,
+                    run_exception_kwargs={},
+                    run_finally_kwargs={
+                        "emit_ran_signal": emit_ran_signal,
+                        "raise_run_exceptions": raise_run_exceptions,
+                    },
+                )
+            else:
+                raise ValueError(
+                    f"{self.full_label} is still waiting for a serialized result"
+                )
+
         self.set_input_values(*args, **kwargs)
 
         return super().run(
@@ -520,6 +567,9 @@ def _run_finally(self, /, emit_ran_signal: bool, raise_run_exceptions: bool):
                 backend=self.recovery, filename=self.as_path().joinpath("recovery")
             )
 
+        if self._do_clean:
+            self._clean_graph_directory()
+
     def run_data_tree(self, run_parent_trees_too=False) -> None:
         """
         Use topological analysis to build a tree of all upstream dependencies and run
@@ -628,6 +678,21 @@ def cache_hit(self):
         except:
             return False
 
+    @property
+    def _temporary_result_file(self):
+        return self.as_path().joinpath("run_result.tmp")
+
+    def _temporary_result_pickle(self, results):
+        self._temporary_result_file.parent.mkdir(parents=True, exist_ok=True)
+        self._temporary_result_file.touch(exist_ok=False)
+        with self._temporary_result_file.open("wb") as f:
+            cloudpickle.dump(results, f)
+
+    def _temporary_result_unpickle(self):
+        with self._temporary_result_file.open("rb") as f:
+            results = cloudpickle.load(f)
+        return results
+
     def _outputs_to_run_return(self):
         return DotDict(self.outputs.to_value_dict())
 
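cloudpickle is used here rather than the standard-library pickle presumably because run results may reference objects plain pickle rejects (lambdas, locally or dynamically defined functions and classes); note also that `touch(exist_ok=False)` prevents silently overwriting an existing temporary result. A standalone roundtrip for illustration:

    import cloudpickle

    result = {"y": 43, "postprocess": lambda v: v - 1}  # the lambda defeats plain pickle
    restored = cloudpickle.loads(cloudpickle.dumps(result))
    assert restored["postprocess"](restored["y"]) == 42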
@@ -994,6 +1059,22 @@ def report_import_readiness(self, tabs=0, report_so_far=""):
             f"{'ok' if self.import_ready else 'NOT IMPORTABLE'}"
         )
 
+    def _clean_graph_directory(self):
+        """
+        Delete the temporary results file (if any), and then go from this node's
+        semantic directory up to its semantic root's directory removing any empty
+        directories. Note: doesn't do a sophisticated walk, so sibling empty
+        directories will cause a parent to identify as non-empty.
+        """
+        self._temporary_result_file.unlink(missing_ok=True)
+
+        # Recursively remove empty directories
+        root_directory = self.semantic_root.as_path().parent
+        for parent in self._temporary_result_file.parents:
+            if parent == root_directory or not parent.exists() or any(parent.iterdir()):
+                break
+            parent.rmdir()
+
     def display_state(self, state=None, ignore_private=True):
         state = dict(self.__getstate__()) if state is None else state
         if self.parent is not None:
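To make the cleanup walk concrete, here is a standalone sketch with hypothetical paths: starting from the temporary file's own directory, each ancestor is removed only while it is empty, stopping at the semantic root's parent or at the first non-empty (or already-missing) directory:

    from pathlib import Path

    tmp = Path("graph_root/subgraph/node/run_result.tmp")  # hypothetical layout
    root_directory = Path("graph_root").parent  # stands in for semantic_root.as_path().parent

    tmp.unlink(missing_ok=True)
    for parent in tmp.parents:  # node/, then subgraph/, then graph_root/, ...
        if parent == root_directory or not parent.exists() or any(parent.iterdir()):
            break
        parent.rmdir()  # rmdir only ever succeeds on empty directories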

pyiron_workflow/nodes/composite.py

+16 −7
@@ -141,16 +141,27 @@ def deactivate_strict_hints(self):
         for node in self:
             node.deactivate_strict_hints()
 
-    def on_run(self):
+    def _on_run(self):
         # Reset provenance and run status trackers
         self.provenance_by_execution = []
         self.provenance_by_completion = []
-        self.running_children = []
+        self.running_children = [n.label for n in self if n.running]
         self.signal_queue = []
 
-        for node in self.starting_nodes:
-            node.run()
+        if len(self.running_children) > 0:  # Start from a broken process
+            for label in self.running_children:
+                self.children[label].run()
+                # Running children will find serialized result and proceed,
+                # or raise an error because they're already running
+        else:  # Start fresh
+            for node in self.starting_nodes:
+                node.run()
 
+        self._run_while_children_or_signals_exist()
+
+        return self
+
+    def _run_while_children_or_signals_exist(self):
         errors = {}
         while len(self.running_children) > 0 or len(self.signal_queue) > 0:
             try:
@@ -172,8 +183,6 @@ def on_run(self):
                 f"{self.full_label} encountered multiple errors in children: {errors}"
             ) from None
 
-        return self
-
     def register_child_starting(self, child: Node) -> None:
         """
         To be called by children when they start their run cycle.
@@ -218,7 +227,7 @@ def register_child_emitting(self, child: Node) -> None:
         self.signal_queue.append((firing, receiving))
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         return (), {}
 
     def process_run_result(self, run_output):
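The practical effect is that a composite recovered from a dead parent process can simply be re-run: `_on_run` seeds `running_children` from any children still flagged as running and re-runs just those, and each such child either deserializes its finished result or raises because it has none yet. A rough sketch, where `comp` is a hypothetical recovered composite:

    try:
        comp.run()  # re-runs only the still-`running` children
    except ValueError:
        # A child had no run_result.tmp yet, i.e. its executor genuinely
        # hasn't finished; try again later.
        pass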

pyiron_workflow/nodes/for_loop.py

+2 −2
@@ -236,9 +236,9 @@ def _setup_node(self) -> None:
         self.starting_nodes = input_nodes
         self._input_node_labels = tuple(n.label for n in input_nodes)
 
-    def on_run(self):
+    def _on_run(self):
         self._build_body()
-        return super().on_run()
+        return super()._on_run()
 
     def _build_body(self):
         """

pyiron_workflow/nodes/function.py

+2 −2
@@ -313,11 +313,11 @@ def _build_outputs_preview(cls) -> dict[str, Any]:
         return preview if len(preview) > 0 else {"None": type(None)}
         # If clause facilitates functions with no return value
 
-    def on_run(self, **kwargs):
+    def _on_run(self, **kwargs):
         return self.node_function(**kwargs)
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         kwargs = self.inputs.to_value_dict()
         return (), kwargs
 
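For function nodes, then, `_run_args` is simply the inputs' value dictionary, and the public `run_args` prepends the serialization flag. A quick illustration, reusing the assumed `Workflow.create.function_node` factory from the sketch at the top:

    def add(x, y):
        return x + y

    n = Workflow.create.function_node(add, x=1, y=2)
    args, kwargs = n.run_args
    # args == (False,)  -- the node's _serialize_result flag, injected by Node.run_args
    # kwargs == {"x": 1, "y": 2}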
pyiron_workflow/nodes/transform.py

+12 −12
@@ -36,14 +36,14 @@ class FromManyInputs(Transformer, ABC):
 
     # _build_inputs_preview required from parent class
     # Inputs convert to `run_args` as a value dictionary
-    # This must be commensurate with the internal expectations of on_run
+    # This must be commensurate with the internal expectations of _on_run
 
     @abstractmethod
-    def on_run(self, **inputs_to_value_dict) -> Any:
+    def _on_run(self, **inputs_to_value_dict) -> Any:
         """Must take inputs kwargs"""
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         return (), self.inputs.to_value_dict()
 
     @classmethod
@@ -64,11 +64,11 @@ class ToManyOutputs(Transformer, ABC):
     # Must be commensurate with the dictionary returned by transform_to_output
 
     @abstractmethod
-    def on_run(self, input_object) -> callable[..., Any | tuple]:
+    def _on_run(self, input_object) -> callable[..., Any | tuple]:
         """Must take the single object to be transformed"""
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         return (self.inputs[self._input_name].value,), {}
 
     @classmethod
@@ -89,7 +89,7 @@ class InputsToList(_HasLength, FromManyInputs, ABC):
     _output_name: ClassVar[str] = "list"
     _output_type_hint: ClassVar[Any] = list
 
-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
         return list(inputs_to_value_dict.values())
 
     @classmethod
@@ -101,7 +101,7 @@ class ListToOutputs(_HasLength, ToManyOutputs, ABC):
     _input_name: ClassVar[str] = "list"
     _input_type_hint: ClassVar[Any] = list
 
-    def on_run(self, input_object: list):
+    def _on_run(self, input_object: list):
         return {f"item_{i}": v for i, v in enumerate(input_object)}
 
     @classmethod
@@ -184,7 +184,7 @@ class InputsToDict(FromManyInputs, ABC):
         list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]]
     ]
 
-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
         return inputs_to_value_dict
 
     @classmethod
@@ -284,7 +284,7 @@ class InputsToDataframe(_HasLength, FromManyInputs, ABC):
     _output_name: ClassVar[str] = "df"
     _output_type_hint: ClassVar[Any] = DataFrame
 
-    def on_run(self, *rows: dict[str, Any]) -> Any:
+    def _on_run(self, *rows: dict[str, Any]) -> Any:
         df_dict = {}
         for i, row in enumerate(rows):
             for key, value in row.items():
@@ -295,7 +295,7 @@ def on_run(self, *rows: dict[str, Any]) -> Any:
         return DataFrame(df_dict)
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         return tuple(self.inputs.to_value_dict().values()), {}
 
     @classmethod
@@ -363,11 +363,11 @@ def _setup_node(self) -> None:
         ):
             self.inputs[name] = self._dataclass_fields[name].default_factory()
 
-    def on_run(self, **inputs_to_value_dict):
+    def _on_run(self, **inputs_to_value_dict):
         return self.dataclass(**inputs_to_value_dict)
 
     @property
-    def run_args(self) -> tuple[tuple, dict]:
+    def _run_args(self) -> tuple[tuple, dict]:
         return (), self.inputs.to_value_dict()
 
     @classmethod

tests/unit/nodes/test_composite.py

+32
@@ -603,6 +603,38 @@ def test_with_executor(self):
                     "retain its executor"
                 )
 
+    def test_result_serialization(self):
+        """
+        This is actually only a useful feature if you have an executor which will
+        continue the process _after_ the parent python process has been shut down
+        (e.g. you sent the run code off to a slurm queue using `executorlib`.), but
+        we'll ensure that the plumbing works here by faking things a bit.
+        """
+        self.comp.use_cache = False
+
+        self.comp.child = Composite.create.function_node(plus_one, x=42)
+        self.comp.starting_nodes = [self.comp.child]
+
+        self.comp.child._serialize_result = True
+        self.comp.child.use_cache = False
+        self.comp.child._do_clean = False
+
+        out = self.comp.run()
+        self.assertTrue(self.comp.child._temporary_result_file.is_file())
+        self.assertEqual(self.comp.child.outputs.y.value, 42 + 1)
+
+        self.comp.child.running = True  # Fake it
+        self.comp.child._do_clean = True  # Clean up this time
+        self.comp.run()
+
+        self.assertFalse(self.comp.child._temporary_result_file.is_file())
+        self.assertEqual(self.comp.child.outputs.y.value, 42 + 1)
+        self.assertFalse(
+            self.comp.as_path().is_dir(),
+            msg="Actually, we expect cleanup to have removed empty directories up to "
+            "and including the semantic root's own directory"
+        )
+
 
 if __name__ == '__main__':
     unittest.main()
