From a5aeccbd5cafe298db1fe7596f61654df27978bd Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 14 Apr 2023 16:01:33 +0200 Subject: [PATCH] fix: behavior of split healthy in extreme cases (#35) --- rul_datasets/adaption.py | 2 +- tests/test_adaption.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/rul_datasets/adaption.py b/rul_datasets/adaption.py index 8eb4380..737d8d8 100644 --- a/rul_datasets/adaption.py +++ b/rul_datasets/adaption.py @@ -357,7 +357,7 @@ def _get_sections( split_idx = cast(int, target.flip(0).argmax().item()) sections = [len(target) - split_idx, split_idx] else: - by_steps = cast(int, by_steps) + by_steps = min(cast(int, by_steps), len(target)) sections = [by_steps, len(target) - by_steps] return sections diff --git a/tests/test_adaption.py b/tests/test_adaption.py index 3aed7ed..0eee5db 100644 --- a/tests/test_adaption.py +++ b/tests/test_adaption.py @@ -471,3 +471,14 @@ def test_split_healthy(features, targets, by_max_rul, by_steps): assert len(degraded_sample) == 3 # features, degradation steps, and labels assert degraded_sample[0].shape == (2, 100) # features are channel first assert degraded_sample[1] == i # degradation step is timestep since healthy + + +@pytest.mark.parametrize(["by_max_rul", "by_steps"], [(True, None), (False, 15)]) +def test_split_healthy_no_degraded(by_steps, by_max_rul): + features = [np.random.randn(11, 100, 2)] + targets = [np.ones(11) * 125] + + healthy, degraded = adaption.split_healthy(features, targets, by_max_rul, by_steps) + + assert len(healthy) == 11 + assert len(degraded) == 0