Skip to content

Commit

Permalink
fix: load_multiple crashed with empty input list (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 authored Apr 13, 2023
1 parent 780d6fd commit 565156b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions rul_datasets/reader/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ def load_multiple(
features: The feature arrays saved in `save_paths`
targets: The target arrays saved in `save_paths`
"""
runs = [load(save_path, memmap) for save_path in save_paths]
features, targets = [list(x) for x in zip(*runs)]
if save_paths:
runs = [load(save_path, memmap) for save_path in save_paths]
features, targets = [list(x) for x in zip(*runs)]
else:
features, targets = [], []

return features, targets

Expand Down
11 changes: 11 additions & 0 deletions tests/reader/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
from pathlib import Path
from unittest import mock

import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -36,6 +37,16 @@ def test_load(tmp_path, file_name):
npt.assert_equal(loaded_targets, targets)


@mock.patch("rul_datasets.reader.saving.load", return_value=(None, None))
@pytest.mark.parametrize("file_names", [["run1", "run2"], []])
def test_load_multiple(mock_load, file_names):
features, targets = saving.load_multiple(file_names)

mock_load.assert_has_calls([mock.call(name, False) for name in file_names])
assert len(features) == len(file_names)
assert len(targets) == len(file_names)


@pytest.mark.parametrize("file_name", ["run", "run.npy"])
def test_exists(tmp_path, file_name):
save_path = os.path.join(tmp_path, file_name)
Expand Down

0 comments on commit 565156b

Please sign in to comment.