Read only the necessary input on each process
tjwsch committed Jun 13, 2024
1 parent 6de8950 commit 703fb70
Showing 3 changed files with 75 additions and 26 deletions.
58 changes: 42 additions & 16 deletions micro_manager/snapshot/dataset.py
@@ -57,27 +57,29 @@ def collect_output_files(
main_file.create_dataset(
key,
shape=(database_length, *current_data.shape),
chunks=(1, *current_data.shape),
fillvalue=np.nan,
)
# Loop over files
crashed_snapshots = []
current_position = 0
outer_position = 0
for file in file_list:
parameter_file = h5py.File(os.path.join(dir_name, file), "r")
# Add all data sets to the main file.
for key in parameter_file.keys():
current_data = parameter_file[key][:]
current_length = len(current_data)
# If the key is "crashed_snapshots" add the indices to the list of crashed snapshots
# Otherwise write the data to the main file
if key == "crashed_snapshots":
crashed_snapshots.extend(current_position + parameter_file[key][:])
else:
main_file[key][
current_position : current_position + current_length
] = current_data

current_position += current_length
inner_position = outer_position
for chunk in parameter_file[key].iter_chunks():
current_data = parameter_file[key][chunk]
# If the key is "crashed_snapshots" add the indices to the list of crashed snapshots
# Otherwise write the data to the main file
if key == "crashed_snapshots":
crashed_snapshots.extend(
inner_position + parameter_file[key][:]
)
else:
main_file[key][inner_position] = current_data
inner_position += 1
outer_position = inner_position
parameter_file.close()
os.remove(os.path.join(dir_name, file))

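The rewritten merge loop walks each per-rank dataset chunk by chunk instead of loading the whole dataset into memory at once; because the datasets are stored with chunks of shape (1, *data_shape), one chunk is exactly one snapshot. A minimal sketch of the same pattern in plain h5py, assuming hypothetical file and dataset names and a pre-allocated, chunked target dataset:

import h5py

# Copy one per-rank output file into the main database, one chunk (one snapshot) at a time.
with h5py.File("snapshot_data.hdf5", "a") as main_file, h5py.File(
    "output_rank_0.hdf5", "r"
) as rank_file:
    source = rank_file["micro_vector_data"]
    position = 0  # write offset in the main database
    for chunk in source.iter_chunks():  # yields one slice tuple per stored chunk
        main_file["micro_vector_data"][position : position + 1] = source[chunk]
        position += 1

Only a single chunk is resident in memory at any time, so the footprint of collect_output_files no longer grows with the number of snapshots a rank produced.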
@@ -124,7 +126,10 @@ def write_output_to_hdf(
for key in input_data.keys():
current_data = np.asarray(input_data[key])
parameter_file.create_dataset(
key, shape=(length, *current_data.shape), fillvalue=np.nan
key,
shape=(length, *current_data.shape),
chunks=(1, *current_data.shape),
fillvalue=np.nan,
)
self._has_datasets = True

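Storing every snapshot in its own chunk is what makes the chunk-wise merge above possible, and the NaN fill value keeps entries that are never written clearly distinguishable. A self-contained sketch of this layout, with illustrative names and sizes:

import h5py
import numpy as np

length = 4             # illustrative number of snapshots on this rank
sample = np.zeros(3)   # illustrative shape of one output field

with h5py.File("output_rank_0.hdf5", "w") as parameter_file:
    dataset = parameter_file.create_dataset(
        "micro_vector_data",
        shape=(length, *sample.shape),
        chunks=(1, *sample.shape),  # one snapshot per chunk
        fillvalue=np.nan,           # snapshots that are never written stay NaN
    )
    dataset[0] = np.array([1.0, 2.0, 3.0])  # write the first snapshot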
@@ -134,7 +139,7 @@ def write_output_to_hdf(
parameter_file[key][idx] = current_data
parameter_file.close()

def read_hdf(self, file_path: str, data_names: dict) -> list:
def read_hdf(self, file_path: str, data_names: dict, start: int, end: int) -> list:
"""
Read data from an HDF5 file and return it as a list of dictionaries.
@@ -144,6 +149,10 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
Path of file to read data from.
data_names : dict
Names of parameters to read from the file.
start : int
Index of the first snapshot to read on this process.
end : int
Index after the last snapshot to read on this process (the slice end is exclusive).
Returns
-------
@@ -156,7 +165,7 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
output = []
# Read data by iterating over the relevant datasets
for key in data_names.keys():
parameter_data[key] = np.asarray(parameter_file[key][:])
parameter_data[key] = np.asarray(parameter_file[key][start:end])
my_key = (
key # Save one key to be able to iterate over the length of the data
)
@@ -169,6 +178,23 @@ def read_hdf(self, file_path: str, data_names: dict) -> list:
output.append(current_data)
return output

def get_length(self, file_path: str) -> int:
"""
Get the length of the parameter space from the HDF5 file.
Parameters
----------
file_path : str
Path of file to read data from.
Returns
-------
int
Size of the parameter space.
"""
with h5py.File(file_path, "r") as file:
return file[list(file.keys())[0]].len()

def write_crashed_snapshots_macro(self, file_path: str, crashed_input: list):
"""
Write indices of crashed snapshots to the HDF5 database.
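Taken together, the new get_length and the start/end arguments of read_hdf let a rank first determine the size of the parameter space from dataset metadata and then pull only its own slice of snapshots from disk. A hypothetical usage of both; the import path, logger argument, and file name are assumptions for illustration:

import logging
from micro_manager.snapshot.dataset import ReadWriteHDF  # assumed module path

data_storage = ReadWriteHDF(logging.getLogger(__name__))
read_data_names = {"macro_scalar_data": False, "macro_vector_data": True}

# Size of the parameter space, taken from the first dataset's length only.
parameter_space_size = data_storage.get_length("parameters.hdf5")

# Read only the block of snapshots [10, 20) that this rank owns.
local_parameters = data_storage.read_hdf("parameters.hdf5", read_data_names, 10, 20)
# local_parameters is a list of dicts, one per snapshot in the slice.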
20 changes: 13 additions & 7 deletions micro_manager/snapshot/snapshot.py
@@ -164,22 +164,28 @@ def initialize(self) -> None:
# Create object responsible for reading parameters and writing simulation output
self._data_storage = ReadWriteHDF(self._logger)

self._parameter_space_size = self._data_storage.get_length(self._parameter_file)
# Read macro parameters from the parameter file
self._macro_parameters = self._data_storage.read_hdf(
self._parameter_file, self._read_data_names
)
self._parameter_space_size = len(self._macro_parameters)
# Decompose parameters if the snapshot creation is executed in parallel
if self._is_parallel:
equal_partition = int(len(self._macro_parameters) / self._size)
rest = len(self._macro_parameters) % self._size
equal_partition = int(self._parameter_space_size / self._size)
rest = self._parameter_space_size % self._size
if self._rank < rest:
start = self._rank * (equal_partition + 1)
end = start + equal_partition + 1
else:
start = self._rank * equal_partition + rest
end = start + equal_partition
self._macro_parameters = self._macro_parameters[start:end]
self._macro_parameters = self._data_storage.read_hdf(
self._parameter_file, self._read_data_names, start, end
)
else:
self._macro_parameters = self._data_storage.read_hdf(
self._parameter_file,
self._read_data_names,
0,
self._parameter_space_size,
)

# Create database file to store output from a rank in
if self._is_parallel:
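The decomposition above assigns every rank a contiguous block of snapshots and gives the first parameter_space_size % size ranks one extra snapshot, so block sizes differ by at most one. The same arithmetic as a standalone sketch (the function name and example values are illustrative):

def partition(parameter_space_size: int, size: int, rank: int) -> tuple:
    """Return the [start, end) range of snapshot indices owned by rank out of size ranks."""
    equal_partition = parameter_space_size // size
    rest = parameter_space_size % size
    if rank < rest:
        start = rank * (equal_partition + 1)
        end = start + equal_partition + 1
    else:
        start = rank * equal_partition + rest
        end = start + equal_partition
    return start, end

# 10 snapshots on 4 ranks -> (0, 3), (3, 6), (6, 8), (8, 10)
print([partition(10, 4, r) for r in range(4)])

Each rank then passes its (start, end) pair to read_hdf, so no rank ever loads the full parameter file.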
23 changes: 20 additions & 3 deletions tests/unit/test_hdf5_functionality.py
@@ -54,7 +54,10 @@ def test_collect_output_files(self):
for key in data.keys():
current_data = np.asarray(data[key])
f.create_dataset(
key, data=current_data, shape=(1, *current_data.shape)
key,
data=current_data,
shape=(1, *current_data.shape),
chunks=(1, *current_data.shape),
)
# Ensure output file does not exist
if os.path.isfile(os.path.join(dir_name, "snapshot_data.hdf5")):
@@ -97,7 +100,7 @@ def test_simulation_output_to_hdf(self):
if os.path.isfile(file_name):
os.remove(file_name)

# Create artifical output data
# Create artificial output data
macro_data = {
"macro_vector_data": np.array([3, 1, 2]),
"macro_scalar_data": 2,
Expand Down Expand Up @@ -148,9 +151,23 @@ def test_hdf_to_dict(self):
)
read_data_names = {"macro_vector_data": True, "macro_scalar_data": False}
data_manager = ReadWriteHDF(MagicMock())
read = data_manager.read_hdf(file_name, read_data_names)
read = data_manager.read_hdf(file_name, read_data_names, 0, -1)
for i in range(len(read)):
self.assertEqual(read[i]["macro_scalar_data"], expected_macro_scalar)
self.assertListEqual(
read[i]["macro_vector_data"].tolist(), expected_macro_vector.tolist()
)

def test_get_parameter_space_length(self):
"""
Test if reading the length of the parameter space works as expected.
"""
file_name = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"hdf_files",
"test_parameter.hdf5",
)
data_manager = ReadWriteHDF(MagicMock())

data_manager.get_length(file_name)
self.assertEqual(data_manager.get_length(file_name), 1)
