Skip to content

Commit

Permalink
feat: extend hyperparameters of data modules (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 authored Nov 16, 2023
1 parent 89ca926 commit 3347b65
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 77 deletions.
12 changes: 5 additions & 7 deletions rul_datasets/adaption.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
source: The data module of the labeled source domain.
target: The data module of the unlabeled target domain.
paired_val: Whether to include paired data in validation.
inductive: Whether to use the target test set for training.
"""
super().__init__()

Expand All @@ -73,13 +74,10 @@ def __init__(

self.save_hyperparameters(
{
"fd_source": self.source.reader.fd,
"fd_target": self.target.reader.fd,
"batch_size": self.batch_size,
"window_size": self.source.reader.window_size,
"max_rul": self.source.reader.max_rul,
"percent_broken": self.target.reader.percent_broken,
"percent_fail_runs": self.target.reader.percent_fail_runs,
"source": self.source.hparams,
"target": self.target.hparams,
"paired_val": self.paired_val,
"inductive": self.inductive,
}
)

Expand Down
5 changes: 2 additions & 3 deletions rul_datasets/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@ def __init__(self, data_module: RulDataModule) -> None:
super().__init__()

self.data_module = data_module
hparams = self.data_module.hparams
self.save_hyperparameters(hparams)
self.save_hyperparameters(self.data_module.hparams)

self.subsets = {}
for fd in self.data_module.fds:
self.subsets[fd] = self._get_fd(fd)

def _get_fd(self, fd):
if fd == self.hparams["fd"]:
if fd == self.data_module.reader.fd:
dm = self.data_module
else:
loader = deepcopy(self.data_module.reader)
Expand Down
15 changes: 8 additions & 7 deletions rul_datasets/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Basic data modules for experiments involving only a single subset of any RUL
dataset. """

from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Any, Callable

import numpy as np
Expand Down Expand Up @@ -105,12 +104,14 @@ def __init__(
"to set a window size for re-windowing."
)

hparams = deepcopy(self.reader.hparams)
hparams["batch_size"] = self.batch_size
hparams["feature_extractor"] = (
str(self.feature_extractor) if self.feature_extractor else None
)
hparams["window_size"] = self.window_size or hparams["window_size"]
hparams = {
"reader": self.reader.hparams,
"batch_size": self.batch_size,
"feature_extractor": (
str(self.feature_extractor) if self.feature_extractor else None
),
"window_size": self.window_size,
}
self.save_hyperparameters(hparams)

@property
Expand Down
12 changes: 9 additions & 3 deletions rul_datasets/reader/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def __init__(

@property
def hparams(self) -> Dict[str, Any]:
"""A dictionary containing all input arguments of the constructor. This
information is used by the data modules to log their hyperparameters in
PyTorch Lightning."""
"""All information logged by the data modules as hyperparameters in PyTorch
Lightning."""
return {
"dataset": self.dataset_name,
"fd": self.fd,
"window_size": self.window_size,
"max_rul": self.max_rul,
Expand All @@ -105,6 +105,12 @@ def hparams(self) -> Dict[str, Any]:
"truncate_degraded_only": self.truncate_degraded_only,
}

    @property
    @abc.abstractmethod
    def dataset_name(self) -> str:
        """Name of the dataset, logged as the `dataset` hyperparameter.

        Must be overridden by each concrete reader with a short, stable
        identifier (e.g. "cmapss", "femto").
        """
        raise NotImplementedError

@property
@abc.abstractmethod
def fds(self) -> List[int]:
Expand Down
8 changes: 6 additions & 2 deletions rul_datasets/reader/cmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
class CmapssReader(AbstractReader):
"""
This reader represents the NASA CMAPSS Turbofan Degradation dataset. Each of its
four sub-datasets contain a training and a test split. Upon first usage,
four sub-datasets contains a training and a test split. Upon first usage,
the training split will be further divided into a development and a validation
split. 20% of the original training split are reserved for validation.
split. 20% of the original training split is reserved for validation.
The features are provided as sliding windows over each time series in the
dataset. The label of a window is the label of its last time step. The RUL labels
Expand Down Expand Up @@ -128,6 +128,10 @@ def __init__(
self.feature_select = feature_select
self.operation_condition_aware_scaling = operation_condition_aware_scaling

    @property
    def dataset_name(self) -> str:
        """Name of the dataset: `"cmapss"`."""
        return "cmapss"

@property
def fds(self) -> List[int]:
"""Indices of available sub-datasets."""
Expand Down
9 changes: 7 additions & 2 deletions rul_datasets/reader/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DummyReader(AbstractReader):
"""

_FDS = [1, 2]

_DEFAULT_WINDOW_SIZE = 10
_NOISE_FACTOR = {1: 0.01, 2: 0.02}
_OFFSET = {1: 0.5, 2: 0.75}
Expand All @@ -62,12 +63,12 @@ def __init__(
truncate_degraded_only: bool = False,
):
"""
Create a new dummy reader for one of the two sub-datasets. The maximun RUL
Create a new dummy reader for one of the two sub-datasets. The maximum RUL
value is set to 50 by default. Please be aware that changing this value will
lead to different features, too, as they are calculated based on the RUL
values.
For more information about using readers refer to the [reader]
For more information about using readers, refer to the [reader]
[rul_datasets.reader] module page.
Args:
Expand All @@ -94,6 +95,10 @@ def __init__(
scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
self.scaler = scaling.fit_scaler(features, scaler)

@property
def dataset_name(self) -> str:
return "xjtu-sy"

@property
def fds(self) -> List[int]:
"""Indices of available sub-datasets."""
Expand Down
8 changes: 6 additions & 2 deletions rul_datasets/reader/femto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
class FemtoReader(AbstractReader):
"""
This reader represents the FEMTO (PRONOSTIA) Bearing dataset. Each of its three
sub-datasets contain a training and a test split. By default, the reader
sub-datasets contains a training and a test split. By default, the reader
constructs a validation split for sub-datasets 1 and 2 each by taking the first
run of the test split. For sub-dataset 3 the second training run is used for
run of the test split. For sub-dataset 3, the second training run is used for
validation because only one test run is available. The remaining training data is
denoted as the development split. This run to split assignment can be overridden
by setting `run_split_dist`.
Expand Down Expand Up @@ -130,6 +130,10 @@ def __init__(

self._preparator = FemtoPreparator(self.fd, self._FEMTO_ROOT, run_split_dist)

    @property
    def dataset_name(self) -> str:
        """Name of the dataset: `"femto"`."""
        return "femto"

@property
def fds(self) -> List[int]:
"""Indices of available sub-datasets."""
Expand Down
8 changes: 6 additions & 2 deletions rul_datasets/reader/xjtu_sy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
constant. The `norm_rul` argument can then be used to scale the RUL of each
run between zero and one.
For more information about using readers refer to the [reader]
For more information about using readers, refer to the [reader]
[rul_datasets.reader] module page.
Args:
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(

if (first_time_to_predict is not None) and (max_rul is not None):
raise ValueError(
"FemtoReader cannot use 'first_time_to_predict' "
"XjtuSyReader cannot use 'first_time_to_predict' "
"and 'max_rul' in conjunction."
)

Expand All @@ -123,6 +123,10 @@ def __init__(

self._preparator = XjtuSyPreparator(self.fd, self._XJTU_SY_ROOT, run_split_dist)

    @property
    def dataset_name(self) -> str:
        """Name of the dataset: `"xjtu-sy"`."""
        return "xjtu-sy"

@property
def fds(self) -> List[int]:
"""Indices of available sub-datasets."""
Expand Down
9 changes: 1 addition & 8 deletions rul_datasets/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,7 @@ def __init__(self, labeled: RulDataModule, unlabeled: RulDataModule) -> None:
self._check_compatibility()

self.save_hyperparameters(
{
"fd": self.labeled.reader.fd,
"batch_size": self.batch_size,
"window_size": self.labeled.reader.window_size,
"max_rul": self.labeled.reader.max_rul,
"percent_broken_unlabeled": self.unlabeled.reader.percent_broken,
"percent_fail_runs_labeled": self.labeled.reader.percent_fail_runs,
}
{"labeled": self.labeled.hparams, "unlabeled": self.unlabeled.hparams}
)

def _check_compatibility(self) -> None:
Expand Down
46 changes: 28 additions & 18 deletions tests/reader/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rul_datasets import reader


class DummyReader(reader.AbstractReader):
class DummyAbstractReader(reader.AbstractReader):
fd: int
window_size: int
max_rul: int
Expand All @@ -17,6 +17,10 @@ class DummyReader(reader.AbstractReader):

_NUM_TRAIN_RUNS = {1: 100}

    @property
    def dataset_name(self) -> str:
        """Name of the stub dataset used only by these unit tests."""
        return "dummy_abstract"

@property
def fds(self):
return [1]
Expand All @@ -36,38 +40,44 @@ def load_complete_split(
class TestAbstractLoader:
@mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], []))
def test_truncation_dev_split(self, mock_truncate_runs):
this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8)
this = DummyAbstractReader(
1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8
)
this.load_split("dev")
mock_truncate_runs.assert_called_with([], [], 0.2, 0.8, False)

@mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], []))
def test_truncation_val_split(self, mock_truncate_runs):
this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8)
this = DummyAbstractReader(
1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8
)
this.load_split("val")
mock_truncate_runs.assert_not_called()

this = DummyReader(
this = DummyAbstractReader(
1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8, truncate_val=True
)
this.load_split("val")
mock_truncate_runs.assert_called_with([], [], 0.2, degraded_only=False)

@mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], []))
def test_truncation_test_split(self, mock_truncate_runs):
this = DummyReader(1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8)
this = DummyAbstractReader(
1, 30, 125, percent_broken=0.2, percent_fail_runs=0.8
)
this.load_split("val")
mock_truncate_runs.assert_not_called()

def test_check_compatibility(self):
this = DummyReader(1, 30, 125)
this.check_compatibility(DummyReader(1, 30, 125))
this = DummyAbstractReader(1, 30, 125)
this.check_compatibility(DummyAbstractReader(1, 30, 125))
with pytest.raises(ValueError):
this.check_compatibility(DummyReader(1, 20, 125))
this.check_compatibility(DummyAbstractReader(1, 20, 125))
with pytest.raises(ValueError):
this.check_compatibility(DummyReader(1, 30, 120))
this.check_compatibility(DummyAbstractReader(1, 30, 120))

def test_get_compatible_same(self):
this = DummyReader(1, 30, 125)
this = DummyAbstractReader(1, 30, 125)
other = this.get_compatible()
this.check_compatibility(other)
assert other is not this
Expand All @@ -79,7 +89,7 @@ def test_get_compatible_same(self):
assert this.truncate_val == other.truncate_val

def test_get_compatible_different(self):
this = DummyReader(1, 30, 125)
this = DummyAbstractReader(1, 30, 125)
other = this.get_compatible(2, 0.2, 0.8, False)
this.check_compatibility(other)
assert other is not this
Expand All @@ -92,21 +102,21 @@ def test_get_compatible_different(self):
assert not other.truncate_val

def test_get_complement_percentage(self):
this = DummyReader(1, 30, 125, percent_fail_runs=0.8)
this = DummyAbstractReader(1, 30, 125, percent_fail_runs=0.8)
other = this.get_complement(0.8, False)
assert other.percent_fail_runs == list(range(80, 100))
assert 0.8 == other.percent_broken
assert not other.truncate_val

def test_get_complement_idx(self):
this = DummyReader(1, 30, 125, percent_fail_runs=list(range(80)))
this = DummyAbstractReader(1, 30, 125, percent_fail_runs=list(range(80)))
other = this.get_complement(0.8, False)
assert other.percent_fail_runs == list(range(80, 100))
assert 0.8 == other.percent_broken
assert not other.truncate_val

def test_get_complement_empty(self):
this = DummyReader(1, 30, 125) # Uses all runs
this = DummyAbstractReader(1, 30, 125) # Uses all runs
other = this.get_complement(0.8, False)
assert not other.percent_fail_runs # Complement is empty
assert 0.8 == other.percent_broken
Expand All @@ -125,8 +135,8 @@ def test_get_complement_empty(self):
],
)
def test_is_mutually_exclusive(self, runs_this, runs_other, success):
this = DummyReader(1, percent_fail_runs=runs_this)
other = DummyReader(1, percent_fail_runs=runs_other)
this = DummyAbstractReader(1, percent_fail_runs=runs_this)
other = DummyAbstractReader(1, percent_fail_runs=runs_other)

assert this.is_mutually_exclusive(other) == success
assert other.is_mutually_exclusive(this) == success
Expand All @@ -136,7 +146,7 @@ def test_is_mutually_exclusive(self, runs_this, runs_other, success):
[("override", 30, 30), ("min", 15, 15), ("none", 30, 15)],
)
def test_consolidate_window_size(self, mode, expected_this, expected_other):
this = DummyReader(1, window_size=30)
this = DummyAbstractReader(1, window_size=30)
other = this.get_compatible(2, consolidate_window_size=mode)

assert this.window_size == expected_this
Expand All @@ -157,7 +167,7 @@ def test_consolidate_window_size(self, mode, expected_this, expected_other):
)
@mock.patch("rul_datasets.reader.truncating.truncate_runs", return_value=([], []))
def test_alias(self, mock_truncate_runs, split, alias, truncate_val, exp_truncated):
this = DummyReader(1, truncate_val=truncate_val)
this = DummyAbstractReader(1, truncate_val=truncate_val)
this.load_complete_split = mock.Mock(wraps=this.load_complete_split)

this.load_split(split, alias)
Expand Down
Loading

0 comments on commit 3347b65

Please sign in to comment.