Measure time for execution and store it in the HDF5 files #524

Merged: 6 commits merged on Dec 17, 2024. The diff below shows the changes from all commits.
3 changes: 3 additions & 0 deletions executorlib/backend/cache_parallel.py
@@ -1,5 +1,6 @@
import pickle
import sys
import time

import cloudpickle

@@ -32,6 +33,7 @@ def main() -> None:
mpi_size_larger_one = MPI.COMM_WORLD.Get_size() > 1
file_name = sys.argv[1]

time_start = time.time()
if mpi_rank_zero:
apply_dict = backend_load_file(file_name=file_name)
else:
@@ -46,6 +48,7 @@ def main() -> None:
backend_write_file(
file_name=file_name,
output=result,
runtime=time.time() - time_start,
)
MPI.COMM_WORLD.Barrier()

11 changes: 9 additions & 2 deletions executorlib/cache/backend.py
@@ -1,4 +1,5 @@
import os
import time
from typing import Any

from executorlib.cache.shared import FutureItem
@@ -28,21 +29,25 @@ def backend_load_file(file_name: str) -> dict:
return apply_dict


def backend_write_file(file_name: str, output: Any) -> None:
def backend_write_file(file_name: str, output: Any, runtime: float) -> None:
"""
Write the output to an HDF5 file.

Args:
file_name (str): The name of the HDF5 file.
output (Any): The output to be written.
runtime (float): Time for executing function.

Returns:
None

"""
file_name_out = os.path.splitext(file_name)[0]
os.rename(file_name, file_name_out + ".h5ready")
dump(file_name=file_name_out + ".h5ready", data_dict={"output": output})
dump(
file_name=file_name_out + ".h5ready",
data_dict={"output": output, "runtime": runtime},
)
os.rename(file_name_out + ".h5ready", file_name_out + ".h5out")


@@ -57,10 +62,12 @@ def backend_execute_task_in_file(file_name: str) -> None:
None
"""
apply_dict = backend_load_file(file_name=file_name)
time_start = time.time()
result = apply_dict["fn"].__call__(*apply_dict["args"], **apply_dict["kwargs"])
backend_write_file(
file_name=file_name,
output=result,
runtime=time.time() - time_start,
)


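For orientation, here is a minimal sketch of the round trip these backend changes enable: serialize a task to a file, execute it through the cache backend, and read the recorded runtime back. The task.h5in file name and the slow_add workload are illustrative assumptions (the test suite goes through serialize_funct_h5 rather than calling dump directly), and get_runtime comes from the executorlib/standalone/hdf.py changes further down.

import os
import time

from executorlib.cache.backend import backend_execute_task_in_file
from executorlib.standalone.hdf import dump, get_output, get_runtime


def slow_add(a, b):
    time.sleep(0.1)  # make the measured runtime clearly larger than zero
    return a + b


# Hypothetical input file; backend_write_file renames it to task.h5ready and then task.h5out.
file_name = os.path.abspath("task.h5in")
dump(file_name=file_name, data_dict={"fn": slow_add, "args": [1], "kwargs": {"b": 2}})
backend_execute_task_in_file(file_name=file_name)

file_name_out = os.path.splitext(file_name)[0] + ".h5out"
flag, output = get_output(file_name=file_name_out)  # (True, 3) once the result has been written
runtime = get_runtime(file_name=file_name_out)  # wall time of slow_add in seconds, roughly 0.1 here
print(flag, output, runtime)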
3 changes: 3 additions & 0 deletions executorlib/interactive/shared.py
@@ -2,6 +2,7 @@
import os
import queue
import sys
import time
from concurrent.futures import Future
from time import sleep
from typing import Callable, List, Optional
@@ -627,8 +628,10 @@ def _execute_task_with_cache(
f = task_dict.pop("future")
if f.set_running_or_notify_cancel():
try:
time_start = time.time()
result = interface.send_and_receive_dict(input_dict=task_dict)
data_dict["output"] = result
data_dict["runtime"] = time.time() - time_start
dump(file_name=file_name, data_dict=data_dict)
f.set_result(result)
except Exception as thread_exception:
18 changes: 18 additions & 0 deletions executorlib/standalone/hdf.py
@@ -18,6 +18,7 @@ def dump(file_name: str, data_dict: dict) -> None:
"args": "input_args",
"kwargs": "input_kwargs",
"output": "output",
"runtime": "runtime",
"queue_id": "queue_id",
}
with h5py.File(file_name, "a") as fname:
@@ -73,6 +74,23 @@ def get_output(file_name: str) -> Tuple[bool, object]:
return False, None


def get_runtime(file_name: str) -> float:
"""
Get run time from HDF5 file

Args:
file_name (str): file name of the HDF5 file as absolute path

Returns:
float: run time from the execution of the python function
"""
with h5py.File(file_name, "r") as hdf:
if "runtime" in hdf:
return cloudpickle.loads(np.void(hdf["/runtime"]))
else:
return 0.0


def get_queue_id(file_name: str) -> Optional[int]:
with h5py.File(file_name, "r") as hdf:
if "queue_id" in hdf:
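As a quick, hedged sketch of the new getter (file names are illustrative): a file written with a runtime entry returns that value, while a file without one falls back to 0.0, which is the behavior the updated assertions in tests/test_cache_hdf.py below rely on.

from executorlib.standalone.hdf import dump, get_runtime

# Hypothetical file names, written with the existing dump helper.
dump(file_name="with_runtime.h5", data_dict={"output": 3, "runtime": 0.25})
dump(file_name="without_runtime.h5", data_dict={"output": 3})

print(get_runtime(file_name="with_runtime.h5"))  # 0.25
print(get_runtime(file_name="without_runtime.h5"))  # 0.0, the fallback when no runtime entry exists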
12 changes: 11 additions & 1 deletion tests/test_cache_hdf.py
@@ -4,7 +4,13 @@


try:
from executorlib.standalone.hdf import dump, load, get_output, get_queue_id
from executorlib.standalone.hdf import (
dump,
load,
get_output,
get_runtime,
get_queue_id,
)

skip_h5py_test = False
except ImportError:
@@ -34,6 +40,7 @@ def test_hdf_mixed(self):
self.assertEqual(data_dict["args"], [a])
self.assertEqual(data_dict["kwargs"], {"b": b})
flag, output = get_output(file_name=file_name)
self.assertTrue(get_runtime(file_name=file_name) == 0.0)
self.assertFalse(flag)
self.assertIsNone(output)

@@ -49,6 +56,7 @@ def test_hdf_args(self):
self.assertEqual(data_dict["args"], [a, b])
self.assertEqual(data_dict["kwargs"], {})
flag, output = get_output(file_name=file_name)
self.assertTrue(get_runtime(file_name=file_name) == 0.0)
self.assertFalse(flag)
self.assertIsNone(output)

@@ -73,6 +81,7 @@ def test_hdf_kwargs(self):
self.assertEqual(data_dict["kwargs"], {"a": a, "b": b})
self.assertEqual(get_queue_id(file_name=file_name), 123)
flag, output = get_output(file_name=file_name)
self.assertTrue(get_runtime(file_name=file_name) == 0.0)
self.assertFalse(flag)
self.assertIsNone(output)

@@ -87,6 +96,7 @@ def test_hdf_queue_id(self):
)
self.assertEqual(get_queue_id(file_name=file_name), 123)
flag, output = get_output(file_name=file_name)
self.assertTrue(get_runtime(file_name=file_name) == 0.0)
self.assertFalse(flag)
self.assertIsNone(output)

14 changes: 13 additions & 1 deletion tests/test_cache_shared.py
@@ -7,7 +7,7 @@
try:
from executorlib.cache.backend import backend_execute_task_in_file
from executorlib.cache.shared import _check_task_output, FutureItem
from executorlib.standalone.hdf import dump
from executorlib.standalone.hdf import dump, get_runtime
from executorlib.standalone.serialize import serialize_funct_h5

skip_h5io_test = False
@@ -40,6 +40,10 @@ def test_execute_function_mixed(self):
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
self.assertTrue(
get_runtime(file_name=os.path.join(cache_directory, task_key + ".h5out"))
> 0.0
)
Review comment from a contributor on lines +43 to +46:
🛠️ Refactor suggestion

Consider enhancing runtime validation tests

The current runtime validation is minimal and duplicated across test methods. Consider:

  1. Testing more specific runtime ranges based on the function's expected execution time
  2. Extracting the runtime validation into a helper method to reduce duplication
  3. Adding tests for error cases (invalid files, corrupted runtime data)

Here's a suggested refactor to reduce duplication:

def assert_valid_runtime(self, task_key: str, cache_directory: str) -> None:
    """Helper method to validate task runtime"""
    runtime = get_runtime(file_name=os.path.join(cache_directory, task_key + ".h5out"))
    self.assertIsNotNone(runtime, "Runtime should be recorded")
    self.assertGreater(runtime, 0.0, "Runtime should be positive")
    # Add more specific assertions based on expected execution time
    self.assertLess(runtime, 1.0, "Simple addition should take less than 1 second")

Then use it in each test:

self.assert_valid_runtime(task_key, cache_directory)

Also applies to: 70-73, 97-100

future_file_obj = FutureItem(
file_name=os.path.join(cache_directory, task_key + ".h5out")
)
@@ -63,6 +67,10 @@ def test_execute_function_args(self):
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
self.assertTrue(
get_runtime(file_name=os.path.join(cache_directory, task_key + ".h5out"))
> 0.0
)
future_file_obj = FutureItem(
file_name=os.path.join(cache_directory, task_key + ".h5out")
)
@@ -86,6 +94,10 @@ def test_execute_function_kwargs(self):
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
self.assertTrue(
get_runtime(file_name=os.path.join(cache_directory, task_key + ".h5out"))
> 0.0
)
future_file_obj = FutureItem(
file_name=os.path.join(cache_directory, task_key + ".h5out")
)