Commit d7fec16: "fixes"
1 parent 06f85c8

File tree: 5 files changed, +39 -8 lines changed


executorlib/standalone/serialize.py
Lines changed: 6 additions & 8 deletions

@@ -33,6 +33,7 @@ def serialize_funct_h5(
     fn_args: Optional[list] = None,
     fn_kwargs: Optional[dict] = None,
     resource_dict: Optional[dict] = None,
+    cache_key: Optional[str] = None,
 ) -> tuple[str, dict]:
     """
     Serialize a function and its arguments and keyword arguments into an HDF5 file.
@@ -51,6 +52,8 @@ def serialize_funct_h5(
             executor: None,
             hostname_localhost: False,
         }
+        cache_key (str, optional): By default the cache_key is generated based on the function hash; this can be
+            overridden by setting the cache_key.

     Returns:
         Tuple[str, dict]: A tuple containing the task key and the serialized data.
@@ -62,16 +65,11 @@ def serialize_funct_h5(
         fn_kwargs = {}
     if resource_dict is None:
         resource_dict = {}
-    if "cache_key" in resource_dict:
-        task_key = resource_dict["cache_key"]
+    if cache_key is not None:
+        task_key = cache_key
     else:
         binary_all = cloudpickle.dumps(
-            {
-                "fn": fn,
-                "args": fn_args,
-                "kwargs": fn_kwargs,
-                "resource_dict": resource_dict,
-            }
+            {"fn": fn, "args": fn_args, "kwargs": fn_kwargs, "resource_dict": resource_dict}
         )
         task_key = fn.__name__ + _get_hash(binary=binary_all)
     data = {
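
Taken together, this change moves cache_key out of resource_dict and into an explicit keyword: when it is set, the hash computation is skipped entirely and the key is used verbatim as the task key. A minimal sketch of both paths, assuming the import path matches the file location (my_funct is a hypothetical stand-in):

# Sketch of the two key-generation paths; the import path is inferred
# from the file location and my_funct is a stand-in function.
from executorlib.standalone.serialize import serialize_funct_h5


def my_funct(a, b):
    return a + b


# Default path: the task key is the function name plus a hash of the
# cloudpickled payload (fn, args, kwargs, resource_dict).
task_key, data = serialize_funct_h5(fn=my_funct, fn_args=[1], fn_kwargs={"b": 2})

# Explicit path: cache_key bypasses the hash and is used verbatim.
task_key, data = serialize_funct_h5(
    fn=my_funct, fn_args=[1], fn_kwargs={"b": 2}, cache_key="abc"
)
assert task_key == "abc"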

executorlib/task_scheduler/file/shared.py
Lines changed: 3 additions & 0 deletions

@@ -72,6 +72,7 @@ def execute_tasks_h5(
         terminate_function (Callable): The function to terminate the tasks.
         pysqa_config_directory (str, optional): path to the pysqa config directory (only for pysqa based backend).
         backend (str, optional): name of the backend used to spawn tasks.
+        disable_dependencies (bool): Disable resolving future objects during the submission.

     Returns:
         None
@@ -101,11 +102,13 @@ def execute_tasks_h5(
             task_resource_dict.update(
                 {k: v for k, v in resource_dict.items() if k not in task_resource_dict}
             )
+            cache_key = task_resource_dict.pop("cache_key", None)
             task_key, data_dict = serialize_funct_h5(
                 fn=task_dict["fn"],
                 fn_args=task_args,
                 fn_kwargs=task_kwargs,
                 resource_dict=task_resource_dict,
+                cache_key=cache_key,
             )
             if task_key not in memory_dict:
                 if not (
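
The pop() here is the important detail: cache_key is removed from the per-task resource_dict before serialization, so the key never feeds into the hashed payload and is forwarded as an explicit keyword instead. A standalone sketch of that pop-and-forward pattern (dict contents are illustrative):

# Pop-and-forward sketch; the dict contents are illustrative only.
task_resource_dict = {"cores": 1, "cache_key": "abc"}

# pop() removes the key so it cannot influence the hash of the remaining
# resource_dict; it returns None when no cache_key was supplied.
cache_key = task_resource_dict.pop("cache_key", None)

assert task_resource_dict == {"cores": 1}
assert cache_key == "abc"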

executorlib/task_scheduler/interactive/shared.py
Lines changed: 8 additions & 0 deletions

@@ -22,6 +22,7 @@ def execute_tasks(
     hostname_localhost: Optional[bool] = None,
     init_function: Optional[Callable] = None,
     cache_directory: Optional[str] = None,
+    cache_key: Optional[str] = None,
     queue_join_on_shutdown: bool = True,
     **kwargs,
 ) -> None:
@@ -41,6 +42,8 @@ def execute_tasks(
                               option to true
         init_function (Callable): optional function to preset arguments for functions which are submitted later
         cache_directory (str, optional): The directory to store cache files. Defaults to "cache".
+        cache_key (str, optional): By default the cache_key is generated based on the function hash; this can be
+            overridden by setting the cache_key.
         queue_join_on_shutdown (bool): Join communication queue when thread is closed. Defaults to True.
     """
     interface = interface_bootup(
@@ -73,6 +76,7 @@ def execute_tasks(
                 task_dict=task_dict,
                 future_queue=future_queue,
                 cache_directory=cache_directory,
+                cache_key=cache_key,
             )


@@ -129,6 +133,7 @@ def _execute_task_with_cache(
     task_dict: dict,
     future_queue: queue.Queue,
     cache_directory: str,
+    cache_key: Optional[str] = None,
 ):
     """
     Execute the task in the task_dict by communicating it via the interface using the cache in the cache directory.
@@ -139,6 +144,8 @@ def _execute_task_with_cache(
             {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
         future_queue (Queue): Queue for receiving new tasks.
         cache_directory (str): The directory to store cache files.
+        cache_key (str, optional): By default the cache_key is generated based on the function hash; this can be
+            overridden by setting the cache_key.
     """
     from executorlib.task_scheduler.file.hdf import dump, get_output

@@ -147,6 +154,7 @@ def _execute_task_with_cache(
         fn_args=task_dict["args"],
         fn_kwargs=task_dict["kwargs"],
         resource_dict=task_dict.get("resource_dict", {}),
+        cache_key=cache_key,
     )
     os.makedirs(os.path.join(cache_directory, task_key), exist_ok=True)
     file_name = os.path.join(cache_directory, task_key, "cache.h5out")

tests/test_cache_fileexecutor_serial.py
Lines changed: 7 additions & 0 deletions

@@ -42,6 +42,13 @@ def test_executor_mixed(self):
             self.assertEqual(fs1.result(), 3)
             self.assertTrue(fs1.done())

+    def test_executor_mixed_cache_key(self):
+        with FileTaskScheduler(execute_function=execute_in_subprocess) as exe:
+            fs1 = exe.submit(my_funct, 1, b=2, resource_dict={"cache_key": "abc"})
+            self.assertFalse(fs1.done())
+            self.assertEqual(fs1.result(), 3)
+            self.assertTrue(fs1.done())
+
     def test_executor_dependence_mixed(self):
         with FileTaskScheduler(execute_function=execute_in_subprocess) as exe:
             fs1 = exe.submit(my_funct, 1, b=2)

tests/test_singlenodeexecutor_cache.py
Lines changed: 15 additions & 0 deletions

@@ -34,6 +34,21 @@ def test_cache_data(self):
             sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
         )

+    def test_cache_key(self):
+        cache_directory = "./cache"
+        with SingleNodeExecutor(cache_directory=cache_directory) as exe:
+            self.assertTrue(exe)
+            future_lst = [exe.submit(sum, [i, i], resource_dict={"cache_key": "same_" + str(i)}) for i in range(1, 4)]
+            result_lst = [f.result() for f in future_lst]
+
+        cache_lst = get_cache_data(cache_directory=cache_directory)
+        for entry in cache_lst:
+            self.assertTrue("same" in entry["filename"])
+        self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst))
+        self.assertEqual(
+            sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
+        )
+
     def test_cache_error(self):
         cache_directory = "./cache_error"
         with SingleNodeExecutor(cache_directory=cache_directory) as exe:
