Skip to content

Commit

Permalink
fix: behavior of split healthy in extreme cases (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 authored Apr 14, 2023
1 parent 4bbb9f8 commit a5aeccb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion rul_datasets/adaption.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _get_sections(
split_idx = cast(int, target.flip(0).argmax().item())
sections = [len(target) - split_idx, split_idx]
else:
by_steps = cast(int, by_steps)
by_steps = min(cast(int, by_steps), len(target))
sections = [by_steps, len(target) - by_steps]

return sections
Expand Down
11 changes: 11 additions & 0 deletions tests/test_adaption.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,14 @@ def test_split_healthy(features, targets, by_max_rul, by_steps):
assert len(degraded_sample) == 3 # features, degradation steps, and labels
assert degraded_sample[0].shape == (2, 100) # features are channel first
assert degraded_sample[1] == i # degradation step is timestep since healthy


@pytest.mark.parametrize(["by_max_rul", "by_steps"], [(True, None), (False, 15)])
def test_split_healthy_no_degraded(by_steps, by_max_rul):
features = [np.random.randn(11, 100, 2)]
targets = [np.ones(11) * 125]

healthy, degraded = adaption.split_healthy(features, targets, by_max_rul, by_steps)

assert len(healthy) == 11
assert len(degraded) == 0

0 comments on commit a5aeccb

Please sign in to comment.