Skip to content

Commit fd826ae

Browse files
Interactive: refactor task done (#795)
* Interactive: refactor task done * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * update test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5530fb5 commit fd826ae

File tree

2 files changed

+24
-37
lines changed

2 files changed

+24
-37
lines changed

executorlib/task_scheduler/interactive/shared.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,47 +77,38 @@ def execute_tasks(
7777
if error_log_file is not None:
7878
task_dict["error_log_file"] = error_log_file
7979
if cache_directory is None:
80-
_execute_task_without_cache(
81-
interface=interface, task_dict=task_dict, future_queue=future_queue
82-
)
80+
_execute_task_without_cache(interface=interface, task_dict=task_dict)
8381
else:
8482
_execute_task_with_cache(
8583
interface=interface,
8684
task_dict=task_dict,
87-
future_queue=future_queue,
8885
cache_directory=cache_directory,
8986
cache_key=cache_key,
9087
)
88+
_task_done(future_queue=future_queue)
9189

9290

93-
def _execute_task_without_cache(
94-
interface: SocketInterface, task_dict: dict, future_queue: queue.Queue
95-
):
91+
def _execute_task_without_cache(interface: SocketInterface, task_dict: dict):
9692
"""
9793
Execute the task in the task_dict by communicating it via the interface.
9894
9995
Args:
10096
interface (SocketInterface): socket interface for zmq communication
10197
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
10298
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
103-
future_queue (Queue): Queue for receiving new tasks.
10499
"""
105100
f = task_dict.pop("future")
106101
if not f.done() and f.set_running_or_notify_cancel():
107102
try:
108103
f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
109104
except Exception as thread_exception:
110105
interface.shutdown(wait=True)
111-
_task_done(future_queue=future_queue)
112106
f.set_exception(exception=thread_exception)
113-
else:
114-
_task_done(future_queue=future_queue)
115107

116108

117109
def _execute_task_with_cache(
118110
interface: SocketInterface,
119111
task_dict: dict,
120-
future_queue: queue.Queue,
121112
cache_directory: str,
122113
cache_key: Optional[str] = None,
123114
):
@@ -128,7 +119,6 @@ def _execute_task_with_cache(
128119
interface (SocketInterface): socket interface for zmq communication
129120
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
130121
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
131-
future_queue (Queue): Queue for receiving new tasks.
132122
cache_directory (str): The directory to store cache files.
133123
cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
134124
overwritten by setting the cache_key.
@@ -155,16 +145,11 @@ def _execute_task_with_cache(
155145
f.set_result(result)
156146
except Exception as thread_exception:
157147
interface.shutdown(wait=True)
158-
_task_done(future_queue=future_queue)
159148
f.set_exception(exception=thread_exception)
160-
raise thread_exception
161-
else:
162-
_task_done(future_queue=future_queue)
163149
else:
164150
_, _, result = get_output(file_name=file_name)
165151
future = task_dict["future"]
166152
future.set_result(result)
167-
_task_done(future_queue=future_queue)
168153

169154

170155
def _task_done(future_queue: queue.Queue):

tests/test_mpiexecspawner.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,13 @@ def test_execute_task_failed_no_argument(self):
443443
q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
444444
q.put({"shutdown": True, "wait": True})
445445
cloudpickle_register(ind=1)
446+
execute_tasks(
447+
future_queue=q,
448+
cores=1,
449+
openmpi_oversubscribe=False,
450+
spawner=MpiExecSpawner,
451+
)
446452
with self.assertRaises(TypeError):
447-
execute_tasks(
448-
future_queue=q,
449-
cores=1,
450-
openmpi_oversubscribe=False,
451-
spawner=MpiExecSpawner,
452-
)
453453
f.result()
454454
q.join()
455455

@@ -459,13 +459,13 @@ def test_execute_task_failed_wrong_argument(self):
459459
q.put({"fn": calc_array, "args": (), "kwargs": {"j": 4}, "future": f})
460460
q.put({"shutdown": True, "wait": True})
461461
cloudpickle_register(ind=1)
462+
execute_tasks(
463+
future_queue=q,
464+
cores=1,
465+
openmpi_oversubscribe=False,
466+
spawner=MpiExecSpawner,
467+
)
462468
with self.assertRaises(TypeError):
463-
execute_tasks(
464-
future_queue=q,
465-
cores=1,
466-
openmpi_oversubscribe=False,
467-
spawner=MpiExecSpawner,
468-
)
469469
f.result()
470470
q.join()
471471

@@ -533,13 +533,15 @@ def test_execute_task_cache_failed_no_argument(self):
533533
f = Future()
534534
q = Queue()
535535
q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
536+
q.put({"shutdown": True, "wait": True})
536537
cloudpickle_register(ind=1)
538+
execute_tasks(
539+
future_queue=q,
540+
cores=1,
541+
openmpi_oversubscribe=False,
542+
spawner=MpiExecSpawner,
543+
cache_directory="executorlib_cache",
544+
)
537545
with self.assertRaises(TypeError):
538-
execute_tasks(
539-
future_queue=q,
540-
cores=1,
541-
openmpi_oversubscribe=False,
542-
spawner=MpiExecSpawner,
543-
cache_directory="executorlib_cache",
544-
)
546+
f.result()
545547
q.join()

0 commit comments

Comments
 (0)