From 703fb7066d50e60a83b88734d2de604aa7eca343 Mon Sep 17 00:00:00 2001
From: Torben Schiz
Date: Thu, 13 Jun 2024 16:26:43 +0200
Subject: [PATCH] Read only the necessary input on each process

---
 micro_manager/snapshot/dataset.py     | 58 +++++++++++++++++++--------
 micro_manager/snapshot/snapshot.py    | 20 +++++----
 tests/unit/test_hdf5_functionality.py | 23 +++++++++--
 3 files changed, 75 insertions(+), 26 deletions(-)

diff --git a/micro_manager/snapshot/dataset.py b/micro_manager/snapshot/dataset.py
index 298f1f9..a2463a0 100644
--- a/micro_manager/snapshot/dataset.py
+++ b/micro_manager/snapshot/dataset.py
@@ -57,27 +57,29 @@ def collect_output_files(
             main_file.create_dataset(
                 key,
                 shape=(database_length, *current_data.shape),
+                chunks=(1, *current_data.shape),
                 fillvalue=np.nan,
             )
         # Loop over files
         crashed_snapshots = []
-        current_position = 0
+        outer_position = 0
         for file in file_list:
             parameter_file = h5py.File(os.path.join(dir_name, file), "r")
             # Add all data sets to the main file.
             for key in parameter_file.keys():
-                current_data = parameter_file[key][:]
-                current_length = len(current_data)
-                # If the key is "crashed_snapshots" add the indices to the list of crashed snapshots
-                # Otherwise write the data to the main file
-                if key == "crashed_snapshots":
-                    crashed_snapshots.extend(current_position + parameter_file[key][:])
-                else:
-                    main_file[key][
-                        current_position : current_position + current_length
-                    ] = current_data
-
-                current_position += current_length
+                inner_position = outer_position
+                for chunk in parameter_file[key].iter_chunks():
+                    current_data = parameter_file[key][chunk]
+                    # If the key is "crashed_snapshots" add the indices to the list of crashed snapshots
+                    # Otherwise write the data to the main file
+                    if key == "crashed_snapshots":
+                        crashed_snapshots.extend(
+                            inner_position + parameter_file[key][:]
+                        )
+                    else:
+                        main_file[key][inner_position] = current_data
+                    inner_position += 1
+            outer_position = inner_position
             parameter_file.close()
             os.remove(os.path.join(dir_name, file))
 
@@ -124,7 +126,10 @@ def write_output_to_hdf(
             for key in input_data.keys():
                 current_data = np.asarray(input_data[key])
                 parameter_file.create_dataset(
-                    key, shape=(length, *current_data.shape), fillvalue=np.nan
+                    key,
+                    shape=(length, *current_data.shape),
+                    chunks=(1, *current_data.shape),
+                    fillvalue=np.nan,
                 )
             self._has_datasets = True
 
@@ -134,7 +139,7 @@ def write_output_to_hdf(
             parameter_file[key][idx] = current_data
         parameter_file.close()
 
-    def read_hdf(self, file_path: str, data_names: dict) -> list:
+    def read_hdf(self, file_path: str, data_names: dict, start: int, end: int) -> list:
         """
         Read data from an HDF5 file and return it as a list of dictionaries.
 
@@ -144,6 +149,10 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
             Path of file to read data from.
         data_names : dict
             Names of parameters to read from the file.
+        start: int
+            Index of the first snapshot to read on process.
+        end: int
+            Index of the last snapshot to read on process.
 
         Returns
         -------
@@ -156,7 +165,7 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
         output = []
         # Read data by iterating over the relevant datasets
         for key in data_names.keys():
-            parameter_data[key] = np.asarray(parameter_file[key][:])
+            parameter_data[key] = np.asarray(parameter_file[key][start:end])
             my_key = (
                 key  # Save one key to be able to iterate over the length of the data
             )
@@ -169,6 +178,23 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
             output.append(current_data)
         return output
 
+    def get_length(self, file_path: str) -> int:
+        """
+        Get the length of the parameter space from the HDF5 file.
+
+        Parameters
+        ----------
+        file_path : str
+            Path of file to read data from.
+
+        Returns
+        -------
+        int
+            Size of Parameter Space
+        """
+        with h5py.File(file_path, "r") as file:
+            return file[list(file.keys())[0]].len()
+
     def write_crashed_snapshots_macro(self, file_path: str, crashed_input: list):
         """
         Write indices of crashed snapshots to the HDF5 database.
diff --git a/micro_manager/snapshot/snapshot.py b/micro_manager/snapshot/snapshot.py
index d0b089f..2a680e8 100644
--- a/micro_manager/snapshot/snapshot.py
+++ b/micro_manager/snapshot/snapshot.py
@@ -164,22 +164,28 @@ def initialize(self) -> None:
 
         # Create object responsible for reading parameters and writing simulation output
         self._data_storage = ReadWriteHDF(self._logger)
+        self._parameter_space_size = self._data_storage.get_length(self._parameter_file)
         # Read macro parameters from the parameter file
-        self._macro_parameters = self._data_storage.read_hdf(
-            self._parameter_file, self._read_data_names
-        )
-        self._parameter_space_size = len(self._macro_parameters)
         # Decompose parameters if the snapshot creation is executed in parallel
         if self._is_parallel:
-            equal_partition = int(len(self._macro_parameters) / self._size)
-            rest = len(self._macro_parameters) % self._size
+            equal_partition = int(self._parameter_space_size / self._size)
+            rest = self._parameter_space_size % self._size
             if self._rank < rest:
                 start = self._rank * (equal_partition + 1)
                 end = start + equal_partition + 1
             else:
                 start = self._rank * equal_partition + rest
                 end = start + equal_partition
-            self._macro_parameters = self._macro_parameters[start:end]
+            self._macro_parameters = self._data_storage.read_hdf(
+                self._parameter_file, self._read_data_names, start, end
+            )
+        else:
+            self._macro_parameters = self._data_storage.read_hdf(
+                self._parameter_file,
+                self._read_data_names,
+                0,
+                self._parameter_space_size,
+            )
 
         # Create database file to store output from a rank in
         if self._is_parallel:
diff --git a/tests/unit/test_hdf5_functionality.py b/tests/unit/test_hdf5_functionality.py
index 0dfc9c6..01e4572 100644
--- a/tests/unit/test_hdf5_functionality.py
+++ b/tests/unit/test_hdf5_functionality.py
@@ -54,7 +54,10 @@ def test_collect_output_files(self):
             for key in data.keys():
                 current_data = np.asarray(data[key])
                 f.create_dataset(
-                    key, data=current_data, shape=(1, *current_data.shape)
+                    key,
+                    data=current_data,
+                    shape=(1, *current_data.shape),
+                    chunks=(1, *current_data.shape),
                 )
         # Ensure output file does not exist
         if os.path.isfile(os.path.join(dir_name, "snapshot_data.hdf5")):
@@ -97,7 +100,7 @@ def test_simulation_output_to_hdf(self):
         if os.path.isfile(file_name):
             os.remove(file_name)
 
-        # Create artifical output data
+        # Create artificial output data
         macro_data = {
             "macro_vector_data": np.array([3, 1, 2]),
             "macro_scalar_data": 2,
@@ -148,9 +151,23 @@ def test_hdf_to_dict(self):
         )
         read_data_names = {"macro_vector_data": True, "macro_scalar_data": False}
         data_manager = ReadWriteHDF(MagicMock())
-        read = data_manager.read_hdf(file_name, read_data_names)
+        read = data_manager.read_hdf(file_name, read_data_names, 0, -1)
         for i in range(len(read)):
             self.assertEqual(read[i]["macro_scalar_data"], expected_macro_scalar)
             self.assertListEqual(
                 read[i]["macro_vector_data"].tolist(), expected_macro_vector.tolist()
             )
+
+    def test_get_parameter_space_length(self):
+        """
+        Test if reading the length of the parameter space works as expected.
+        """
+        file_name = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)),
+            "hdf_files",
+            "test_parameter.hdf5",
+        )
+        data_manager = ReadWriteHDF(MagicMock())
+
+        data_manager.get_length(file_name)
+        self.assertEqual(data_manager.get_length(file_name), 1)