Skip to content

Commit d8f84a4

Browse files
Write cache first (#492)
* revert test * Write cache before updating future * Add more tests * Add test for error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 038eaec commit d8f84a4

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

executorlib/interactive/shared.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,20 @@ def _execute_task_with_cache(
624624
os.makedirs(cache_directory, exist_ok=True)
625625
file_name = os.path.join(cache_directory, task_key + ".h5out")
626626
if task_key + ".h5out" not in os.listdir(cache_directory):
627-
_execute_task(
628-
interface=interface,
629-
task_dict=task_dict,
630-
future_queue=future_queue,
631-
)
632-
data_dict["output"] = future.result()
633-
dump(file_name=file_name, data_dict=data_dict)
627+
f = task_dict.pop("future")
628+
if f.set_running_or_notify_cancel():
629+
try:
630+
result = interface.send_and_receive_dict(input_dict=task_dict)
631+
data_dict["output"] = result
632+
dump(file_name=file_name, data_dict=data_dict)
633+
f.set_result(result)
634+
except Exception as thread_exception:
635+
interface.shutdown(wait=True)
636+
future_queue.task_done()
637+
f.set_exception(exception=thread_exception)
638+
raise thread_exception
639+
else:
640+
future_queue.task_done()
634641
else:
635642
_, result = get_output(file_name=file_name)
636643
future = task_dict["future"]

tests/test_executor_backend_mpi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def tearDown(self):
9797
)
9898
def test_meta_executor_parallel_cache(self):
9999
with Executor(
100-
max_cores=2,
100+
max_workers=2,
101101
resource_dict={"cores": 2},
102102
backend="local",
103103
block_allocation=True,

tests/test_local_executor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib.util
33
from queue import Queue
44
from time import sleep
5+
import shutil
56
import unittest
67

78
import numpy as np
@@ -16,6 +17,12 @@
1617
from executorlib.standalone.interactive.backend import call_funct
1718
from executorlib.standalone.serialize import cloudpickle_register
1819

20+
try:
21+
import h5py
22+
23+
skip_h5py_test = False
24+
except ImportError:
25+
skip_h5py_test = True
1926

2027
skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None
2128

@@ -473,3 +480,45 @@ def test_execute_task_parallel(self):
473480
)
474481
self.assertEqual(f.result(), [np.array(4), np.array(4)])
475482
q.join()
483+
484+
485+
class TestFuturePoolCache(unittest.TestCase):
486+
def tearDown(self):
487+
shutil.rmtree("./cache")
488+
489+
@unittest.skipIf(
490+
skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
491+
)
492+
def test_execute_task_cache(self):
493+
f = Future()
494+
q = Queue()
495+
q.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f})
496+
q.put({"shutdown": True, "wait": True})
497+
cloudpickle_register(ind=1)
498+
execute_parallel_tasks(
499+
future_queue=q,
500+
cores=1,
501+
openmpi_oversubscribe=False,
502+
spawner=MpiExecSpawner,
503+
cache_directory="./cache",
504+
)
505+
self.assertEqual(f.result(), 1)
506+
q.join()
507+
508+
@unittest.skipIf(
509+
skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
510+
)
511+
def test_execute_task_cache_failed_no_argument(self):
512+
f = Future()
513+
q = Queue()
514+
q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
515+
cloudpickle_register(ind=1)
516+
with self.assertRaises(TypeError):
517+
execute_parallel_tasks(
518+
future_queue=q,
519+
cores=1,
520+
openmpi_oversubscribe=False,
521+
spawner=MpiExecSpawner,
522+
cache_directory="./cache",
523+
)
524+
q.join()

0 commit comments

Comments
 (0)