
Commit 6554702

[ADD] Test evaluator (#368)
* add test evaluator
* add no resampling and other changes for test evaluator
* finalise changes for test_evaluator, TODO: tests
* add tests for new functionality
* fix flake and mypy
* add documentation for the evaluator
* add NoResampling to fit_pipeline
* raise error when trying to construct ensemble with noresampling
* fix tests
* reduce fit_pipeline accuracy check
* Apply suggestions from code review
  Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
* address comments from shuhei
* fix bug in base data loader
* fix bug in data loader for val set
* fix bugs introduced in suggestions
* fix flake
* fix bug in test preprocessing
* fix bug in test data loader
* merge tests for evaluators and change listcomp in get_best_epoch
* rename resampling strategies
* add test for get dataset

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent c0fb82e commit 6554702

Showing 22 changed files with 817 additions and 120 deletions.

autoPyTorch/api/base_task.py

Lines changed: 25 additions & 9 deletions
@@ -40,7 +40,12 @@
 )
 from autoPyTorch.data.base_validator import BaseInputValidator
 from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
+from autoPyTorch.datasets.resampling_strategy import (
+    CrossValTypes,
+    HoldoutValTypes,
+    NoResamplingStrategyTypes,
+    ResamplingStrategies,
+)
 from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
 from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
 from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
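For orientation: ResamplingStrategies, which replaces the Union[CrossValTypes, HoldoutValTypes] annotations throughout this commit, is presumably a type alias covering the three strategy enums in datasets/resampling_strategy.py. The alias definition itself is not part of the hunks shown here, so the sketch below is an inference, not the actual source:

    # Assumed shape of the alias, inferred from the annotations changed in this commit:
    from typing import Union

    ResamplingStrategies = Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]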
@@ -145,6 +150,13 @@ class BaseTask(ABC):
             name and Value is an Iterable of the names of the components
             to exclude. All except these components will be present in
             the search space.
+        resampling_strategy (RESAMPLING_STRATEGIES),
+            (default=HoldoutValTypes.holdout_validation):
+            strategy to split the training data.
+        resampling_strategy_args (Optional[Dict[str, Any]]): arguments
+            required for the chosen resampling strategy. If None, uses
+            the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+            in ```datasets/resampling_strategy.py```.
         search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
             Search space updates that can be used to modify the search
             space of particular components or choice modules of the pipeline
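As a usage note for the docstring added above: resampling_strategy_args is a plain dictionary of strategy-specific options whose defaults live in DEFAULT_RESAMPLING_PARAMETERS. A hedged sketch follows; the key names val_share and num_splits and the enum member names are assumptions, since that module is not shown in this commit:

    # Holdout with a custom validation share (val_share key assumed):
    api = TabularClassificationTask(
        resampling_strategy=HoldoutValTypes.holdout_validation,
        resampling_strategy_args={'val_share': 0.25},
    )

    # Cross-validation with a custom number of folds (num_splits key assumed):
    api = TabularClassificationTask(
        resampling_strategy=CrossValTypes.k_fold_cross_validation,
        resampling_strategy_args={'num_splits': 5},
    )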
@@ -166,11 +178,15 @@ def __init__(
         include_components: Optional[Dict[str, Any]] = None,
         exclude_components: Optional[Dict[str, Any]] = None,
         backend: Optional[Backend] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
         task_type: Optional[str] = None
     ) -> None:
+
+        if isinstance(resampling_strategy, NoResamplingStrategyTypes) and ensemble_size != 0:
+            raise ValueError("`NoResamplingStrategy` cannot be used for ensemble construction")
+
         self.seed = seed
         self.n_jobs = n_jobs
         self.n_threads = n_threads
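The new guard exists because, without a held-out split, there is nothing on which ensemble members could be validated and selected. A hedged sketch of the resulting behaviour (the member name no_resampling and the ensemble_size keyword on the task constructor are assumptions; only the interaction itself comes from the diff):

    from autoPyTorch.api.tabular_classification import TabularClassificationTask
    from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes

    # Accepted: no resampling, ensemble construction disabled explicitly.
    api = TabularClassificationTask(
        resampling_strategy=NoResamplingStrategyTypes.no_resampling,
        ensemble_size=0,
    )

    # Rejected: raises ValueError("`NoResamplingStrategy` cannot be used for ensemble construction").
    api = TabularClassificationTask(
        resampling_strategy=NoResamplingStrategyTypes.no_resampling,
        ensemble_size=50,
    )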
@@ -280,7 +296,7 @@ def _get_dataset_input_validator(
         y_train: Union[List, pd.DataFrame, np.ndarray],
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-        resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
     ) -> Tuple[BaseDataset, BaseInputValidator]:
@@ -298,7 +314,7 @@ def _get_dataset_input_validator(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -322,7 +338,7 @@ def get_dataset(
         y_train: Union[List, pd.DataFrame, np.ndarray],
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-        resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
     ) -> BaseDataset:
@@ -338,7 +354,7 @@ def get_dataset(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -1360,7 +1376,7 @@ def fit_pipeline(
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         dataset_name: Optional[str] = None,
-        resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None,
+        resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         run_time_limit_secs: int = 60,
         memory_limit: Optional[int] = None,
@@ -1395,7 +1411,7 @@ def fit_pipeline(
                 be provided to track the generalization performance of each stage.
             dataset_name (Optional[str]):
                 Name of the dataset, if None, random value is used.
-            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
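With NoResamplingStrategyTypes accepted here, fit_pipeline can train a single pipeline on the full training set, which is the scenario the new test evaluator targets. A hedged sketch, reusing the api object from the earlier example (the data variables and the four-element return value are assumptions based on the surrounding API, not shown in this hunk):

    pipeline, run_info, run_value, dataset = api.fit_pipeline(
        X_train=X_train, y_train=y_train,
        X_test=X_test, y_test=y_test,
        resampling_strategy=NoResamplingStrategyTypes.no_resampling,
        run_time_limit_secs=60,
    )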
@@ -1657,7 +1673,7 @@ def predict(
         # Mypy assert
         assert self.ensemble_ is not None, "Load models should error out if no ensemble"
 
-        if isinstance(self.resampling_strategy, HoldoutValTypes):
+        if isinstance(self.resampling_strategy, (HoldoutValTypes, NoResamplingStrategyTypes)):
             models = self.models_
         elif isinstance(self.resampling_strategy, CrossValTypes):
             models = self.cv_models_

autoPyTorch/api/tabular_classification.py

Lines changed: 12 additions & 5 deletions
@@ -13,8 +13,8 @@
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
-    CrossValTypes,
     HoldoutValTypes,
+    ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import DisableFileOutputParameters
@@ -64,8 +64,15 @@ class TabularClassificationTask(BaseTask):
             name and Value is an Iterable of the names of the components
             to exclude. All except these components will be present in
             the search space.
+        resampling_strategy (RESAMPLING_STRATEGIES),
+            (default=HoldoutValTypes.holdout_validation):
+            strategy to split the training data.
+        resampling_strategy_args (Optional[Dict[str, Any]]): arguments
+            required for the chosen resampling strategy. If None, uses
+            the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+            in ```datasets/resampling_strategy.py```.
         search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
-            search space updates that can be used to modify the search
+            Search space updates that can be used to modify the search
             space of particular components or choice modules of the pipeline
     """
     def __init__(
@@ -83,7 +90,7 @@ def __init__(
         delete_output_folder_after_terminate: bool = True,
         include_components: Optional[Dict[str, Any]] = None,
         exclude_components: Optional[Dict[str, Any]] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         backend: Optional[Backend] = None,
         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
@@ -153,7 +160,7 @@ def _get_dataset_input_validator(
         y_train: Union[List, pd.DataFrame, np.ndarray],
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-        resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
@@ -170,7 +177,7 @@ def _get_dataset_input_validator(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):

autoPyTorch/api/tabular_regression.py

Lines changed: 12 additions & 5 deletions
@@ -13,8 +13,8 @@
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
-    CrossValTypes,
     HoldoutValTypes,
+    ResamplingStrategies,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import DisableFileOutputParameters
@@ -64,8 +64,15 @@ class TabularRegressionTask(BaseTask):
             name and Value is an Iterable of the names of the components
             to exclude. All except these components will be present in
             the search space.
+        resampling_strategy (RESAMPLING_STRATEGIES),
+            (default=HoldoutValTypes.holdout_validation):
+            strategy to split the training data.
+        resampling_strategy_args (Optional[Dict[str, Any]]): arguments
+            required for the chosen resampling strategy. If None, uses
+            the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+            in ```datasets/resampling_strategy.py```.
         search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
-            search space updates that can be used to modify the search
+            Search space updates that can be used to modify the search
             space of particular components or choice modules of the pipeline
     """
 
@@ -84,7 +91,7 @@ def __init__(
         delete_output_folder_after_terminate: bool = True,
         include_components: Optional[Dict[str, Any]] = None,
         exclude_components: Optional[Dict[str, Any]] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         backend: Optional[Backend] = None,
         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
@@ -154,7 +161,7 @@ def _get_dataset_input_validator(
         y_train: Union[List, pd.DataFrame, np.ndarray],
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
-        resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
@@ -171,7 +178,7 @@ def _get_dataset_input_validator(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):

autoPyTorch/datasets/base_dataset.py

Lines changed: 28 additions & 11 deletions
@@ -21,7 +21,11 @@
     DEFAULT_RESAMPLING_PARAMETERS,
     HoldOutFunc,
     HoldOutFuncs,
-    HoldoutValTypes
+    HoldoutValTypes,
+    NoResamplingFunc,
+    NoResamplingFuncs,
+    NoResamplingStrategyTypes,
+    ResamplingStrategies
 )
 from autoPyTorch.utils.common import FitRequirement
 
@@ -78,7 +82,7 @@ def __init__(
         dataset_name: Optional[str] = None,
         val_tensors: Optional[BaseDatasetInputType] = None,
         test_tensors: Optional[BaseDatasetInputType] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
         seed: Optional[int] = 42,
@@ -95,8 +99,7 @@ def __init__(
                 validation data
             test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                 test data
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
-                (default=HoldoutValTypes.holdout_validation):
+            resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
                 strategy to split the training data.
             resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                 required for the chosen resampling strategy. If None, uses
@@ -109,16 +112,18 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        self.dataset_name = dataset_name
 
-        if self.dataset_name is None:
+        if dataset_name is None:
             self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+        else:
+            self.dataset_name = dataset_name
 
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
+        self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
@@ -143,6 +148,8 @@ def __init__(
         # Make sure cross validation splits are created once
         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
+        self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)
+
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
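The new no_resampling_validators registry presumably mirrors the existing holdout and cross-validation registries: a name-to-function mapping built from the NoResamplingStrategyTypes members. A minimal sketch of what get_no_resampling_validators could look like; the real implementation lives in datasets/resampling_strategy.py and is not part of this hunk:

    from typing import Dict

    class NoResamplingFuncs:
        @classmethod
        def get_no_resampling_validators(cls, *no_resampling_types) -> Dict[str, NoResamplingFunc]:
            # Map each enum member's name to the function implementing it,
            # analogous to get_cross_validators / get_holdout_validators.
            return {strategy.name: getattr(cls, strategy.name) for strategy in no_resampling_types}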
@@ -210,7 +217,7 @@ def __len__(self) -> int:
     def _get_indices(self) -> np.ndarray:
         return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
 
-    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
+    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[List[int]]]]:
         """
         Creates a set of splits based on a resampling strategy provided
 
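The widened return type reflects that the second element of a split may now be absent. Illustratively (index values made up):

    # holdout:         splits == [([0, 2, 3, ...], [1, 4, ...])]          # (train_indices, val_indices)
    # 3-fold CV:       splits == [(tr0, va0), (tr1, va1), (tr2, va2)]
    # no resampling:   splits == [([0, 1, 2, ...], None)]                 # no validation indices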
@@ -241,6 +248,9 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
                     num_splits=cast(int, num_splits),
                 )
             )
+        elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
+            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state,
+                                                                                        self._get_indices()), None))
         else:
             raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
         return splits
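The branch above hands the random state and the full index array to the registered NoResamplingFunc and pairs the result with None, so the underlying splitter is presumably an identity-style function. A minimal sketch under that assumption:

    import numpy as np

    def no_resampling(random_state: np.random.RandomState, indices: np.ndarray) -> np.ndarray:
        # Every sample is kept for training; there is no validation partition.
        return indices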
@@ -312,22 +322,29 @@ def create_holdout_val_split(
             self.random_state, val_share, self._get_indices(), **kwargs)
         return train, val
 
-    def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
+    def get_dataset(self, split_id: int, train: bool) -> Dataset:
         """
         The above split methods employ the Subset to internally subsample the whole dataset.
 
         During training, we need access to one of those splits. This is a handy function
         to provide training data to fit a pipeline
 
         Args:
-            split (int): The desired subset of the dataset to split and use
+            split_id (int): which split id to get from the splits
+            train (bool): whether the dataset is required for training or evaluating.
 
         Returns:
             Dataset: the reduced dataset to be used for testing
         """
         # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
-        return (TransformSubset(self, self.splits[split_id][0], train=True),
-                TransformSubset(self, self.splits[split_id][1], train=False))
+        if split_id >= len(self.splits):  # old version: split_id > len(self.splits)
+            raise IndexError(f"self.splits index out of range, got split_id={split_id}"
+                             f" (>= num_splits={len(self.splits)})")
+        indices = self.splits[split_id][int(not train)]  # 0: for training, 1: for evaluation
+        if indices is None:
+            raise ValueError("Specified fold (or subset) does not exist")
+
+        return TransformSubset(self, indices, train=train)
 
     def replace_data(self, X_train: BaseDatasetInputType,
                      X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
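A hedged usage sketch of the reworked accessor, assuming a constructed dataset object:

    train_subset = dataset.get_dataset(split_id=0, train=True)   # TransformSubset over splits[0][0]
    val_subset = dataset.get_dataset(split_id=0, train=False)    # splits[0][1]; raises ValueError when
                                                                  # it is None (no-resampling strategies)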

autoPyTorch/datasets/image_dataset.py

Lines changed: 5 additions & 2 deletions
@@ -24,6 +24,7 @@
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     HoldoutValTypes,
+    NoResamplingStrategyTypes
 )
 
 IMAGE_DATASET_INPUT = Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]
@@ -39,7 +40,7 @@ class ImageDataset(BaseDataset):
             validation data
         test (Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]):
             testing data
-        resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
+        resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
             (default=HoldoutValTypes.holdout_validation):
             strategy to split the training data.
         resampling_strategy_args (Optional[Dict[str, Any]]): arguments
@@ -57,7 +58,9 @@ def __init__(self,
                  train: IMAGE_DATASET_INPUT,
                  val: Optional[IMAGE_DATASET_INPUT] = None,
                  test: Optional[IMAGE_DATASET_INPUT] = None,
-                 resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                 resampling_strategy: Union[CrossValTypes,
+                                            HoldoutValTypes,
+                                            NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
                  resampling_strategy_args: Optional[Dict[str, Any]] = None,
                  shuffle: Optional[bool] = True,
                  seed: Optional[int] = 42,
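A hedged construction sketch for the widened ImageDataset signature; the image paths and targets are placeholders and the member name no_resampling is assumed:

    dataset = ImageDataset(
        train=(train_image_paths, train_targets),   # (images or paths, targets) placeholder tuple
        resampling_strategy=NoResamplingStrategyTypes.no_resampling,
    )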
