Skip to content
Merged
2 changes: 1 addition & 1 deletion src/executorlib/task_scheduler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
self._future_queue.put(
{"shutdown": True, "wait": wait, "cancel_futures": cancel_futures}
)
if wait and isinstance(self._process, Thread):
if isinstance(self._process, Thread):
self._process.join()
self._future_queue.join()
self._process = None
Expand Down
107 changes: 76 additions & 31 deletions src/executorlib/task_scheduler/file/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,37 +96,17 @@ def execute_tasks_h5(
with contextlib.suppress(queue.Empty):
task_dict = future_queue.get_nowait()
if task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"]:
if task_dict["wait"] and wait:
while len(memory_dict) > 0:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
if not task_dict["cancel_futures"] and wait:
_cancel_processes(
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
else:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
for value in memory_dict.values():
if not value.done():
value.cancel()
_shutdown_executor(
wait=wait and task_dict["wait"],
cancel_futures=task_dict.get("cancel_futures", False),
memory_dict=memory_dict,
process_dict=process_dict,
cache_dir_dict=cache_dir_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
future_queue.task_done()
future_queue.join()
break
Expand Down Expand Up @@ -381,3 +361,68 @@ def _get_task_input(
cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory"))
error_log_file = task_resource_dict.pop("error_log_file", None)
return task_resource_dict, cache_key, cache_directory, error_log_file


def _cancel_futures(future_dict: dict):
for value in future_dict.values():
if not value.done():
value.cancel()


def _shutdown_executor(
wait: bool,
cancel_futures: bool,
memory_dict: dict,
process_dict: dict,
cache_dir_dict: dict,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
refresh_rate: float = 0.01,
):
if wait and not cancel_futures:
while len(memory_dict) > 0:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
elif wait and cancel_futures:
for value in memory_dict.values():
if not value.done():
value.cancel()
while len(memory_dict) > 0:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
elif cancel_futures: # wait is False
_cancel_processes(
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
_cancel_futures(future_dict=memory_dict)
else: # wait is False and cancel_futures is False
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
refresh_rate=refresh_rate,
)
# The future objects are detached so mark them as cancelled even though the processes are
# not terminated. This is to prevent the main process from waiting indefinitely for the results.
_cancel_futures(future_dict=memory_dict)
63 changes: 61 additions & 2 deletions tests/unit/executor/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import shutil
import unittest
from concurrent.futures import Future
from time import sleep
from concurrent.futures import Future, wait

from executorlib import get_cache_data, get_future_from_cache
from executorlib.api import TestClusterExecutor
Expand All @@ -10,6 +11,7 @@

try:
import h5py
from executorlib.task_scheduler.file.shared import _shutdown_executor

skip_h5py_test = False
except ImportError:
Expand All @@ -28,6 +30,11 @@ def get_error(i):
raise ValueError(f"error {i}")


def add_with_sleep(parameter_1, parameter_2):
    """Pause for one second, then return the sum of the two parameters.

    The delay keeps the task running long enough for the shutdown tests to
    observe an unfinished future.
    """
    sleep(1)
    result = parameter_1 + parameter_2
    return result


@unittest.skipIf(
skip_h5py_test, "h5py is not installed, so the h5io tests are skipped."
)
Expand Down Expand Up @@ -132,5 +139,57 @@ def test_executor_dependency_plot(self):
self.assertEqual(len(nodes), 4)
self.assertEqual(len(edges), 4)

def test_shutdown_wait_false_cancel_futures_false(self):
    # First run: shut down without waiting; the detached future is reported
    # as done and cancelled even though the background task keeps running.
    executor = TestClusterExecutor(cache_directory="shutdown_1_dir")
    cloudpickle_register(ind=1)
    detached = executor.submit(add_with_sleep, 1, parameter_2=2)
    executor.shutdown(wait=False, cancel_futures=False)
    self.assertTrue(detached.done())
    self.assertTrue(detached.cancelled())
    # Give the detached task time to finish and write to the cache directory.
    sleep(2)
    # Second run with the same cache directory: the identical submission now
    # resolves to a result even though shutdown still does not wait.
    executor = TestClusterExecutor(cache_directory="shutdown_1_dir")
    cloudpickle_register(ind=1)
    cached = executor.submit(add_with_sleep, 1, parameter_2=2)
    executor.shutdown(wait=False, cancel_futures=False)
    self.assertTrue(cached.done())
    self.assertEqual(cached.result(), 3)

def test_shutdown_wait_false_cancel_futures_true(self):
    executor = TestClusterExecutor(cache_directory="shutdown_2_dir")
    cloudpickle_register(ind=1)
    unfinished = executor.submit(add_with_sleep, 1, parameter_2=3)
    # Shutting down with cancel_futures=True marks the unfinished future
    # as cancelled without waiting for it.
    executor.shutdown(wait=False, cancel_futures=True)
    self.assertTrue(unfinished.done())
    self.assertTrue(unfinished.cancelled())

def test_shutdown_wait_true_cancel_futures_true(self):
    executor = TestClusterExecutor(cache_directory="shutdown_3_dir")
    cloudpickle_register(ind=1)
    first = executor.submit(add_with_sleep, 1, parameter_2=3)
    second = executor.submit(add_with_sleep, first, parameter_2=3)
    # Waiting shutdown with cancel_futures=True cancels both the submitted
    # task and the dependent one that consumes its future.
    executor.shutdown(wait=True, cancel_futures=True)
    for future in (first, second):
        self.assertTrue(future.done())
        self.assertTrue(future.cancelled())

def tearDown(self):
    """Remove every cache directory the tests may have created.

    shutil.rmtree with ignore_errors=True silently skips missing paths, so
    neither a separate os.path.exists check nor the previous duplicate
    removal of "rather_this_dir" (which was also first in the list) is needed.
    """
    for cache_dir in [
        "rather_this_dir",
        "shutdown_1_dir",
        "shutdown_2_dir",
        "shutdown_3_dir",
        "cache_dir",
    ]:
        shutil.rmtree(cache_dir, ignore_errors=True)

def test_shutdown_executor_function(self):
    # Drive _shutdown_executor directly: a single pending future, no
    # submission processes. With wait=True and cancel_futures=True the
    # future must end up done and cancelled.
    future = Future()
    memory_dict = {"a": future}
    _shutdown_executor(
        wait=True,
        cancel_futures=True,
        memory_dict=memory_dict,
        process_dict={},
        cache_dir_dict={"a": "cache_dir"},
        terminate_function=None,
        pysqa_config_directory=None,
        backend=None,
        refresh_rate=0.01,
    )
    self.assertTrue(future.done())
    self.assertTrue(future.cancelled())
Loading